gaoqiong / composable_kernel · Commits

Commit e675b5c3
Authored Feb 10, 2023 by guangzlu · parent 26cc4721

    can compile now

Showing 5 changed files with 37 additions and 13 deletions (+37 -13)
example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp  +3 -0
example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute_train.inc  +7 -3
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp  +2 -2
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp  +21 -6
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp  +4 -2
example/32_batched_gemm_scale_softmax_gemm/grouped_gemm_scale_softmax_gemm_permute_train_xdl_fp16.cpp

@@ -33,6 +33,7 @@ using S = ck::Sequence<Is...>;
 using F16 = ck::half_t;
 using F32 = float;
+using U16 = unsigned short;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;

@@ -42,6 +43,7 @@ using B1DataType = F16;
 using AccDataType      = F32;
 using CShuffleDataType = F32;
 using CDataType        = F16;
+using ZDataType        = U16;
 using LSEDataType      = F32;
 using Acc0BiasDataType = ck::Tuple<>;
 using Acc1BiasDataType = ck::Tuple<>;

@@ -78,6 +80,7 @@ using DeviceGemmInstance =
         B0DataType,
         B1DataType,
         CDataType,
+        ZDataType,
         LSEDataType,
         Acc0BiasDataType,
         Acc1BiasDataType,
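Note: ZDataType is a new 16-bit unsigned integer element type threaded into the DeviceGemmInstance parameter list between CDataType and LSEDataType; judging by the IsDropout flag in the kernel changes further down, it presumably carries per-element values for the dropout path, though the commit itself does not say so. A minimal sketch of how the new aliases compose (the dimension values are illustrative, not taken from the example):

#include <cstddef>
#include <cstdio>

using U16       = unsigned short; // new alias from this commit
using ZDataType = U16;            // Z element type fed to the device op

int main()
{
    // Z is laid out as (G0, G1, M, N) in the example, so its storage cost is
    // sizeof(ZDataType) per softmax entry per batch.
    const std::size_t G0 = 4, G1 = 12, M = 256, N = 128;
    const std::size_t z_bytes = sizeof(ZDataType) * G0 * G1 * M * N;
    std::printf("Z tensor footprint: %zu bytes\n", z_bytes); // 3145728 bytes
    return 0;
}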
example/32_batched_gemm_scale_softmax_gemm/run_grouped_gemm_scale_softmax_gemm_permute_train.inc

@@ -56,6 +56,7 @@ int run(int argc, char* argv[])
     std::vector<Tensor<B0DataType>> b0_tensors;
     std::vector<Tensor<B1DataType>> b1_tensors;
     std::vector<Tensor<CDataType>> c_tensors;
+    std::vector<Tensor<ZDataType>> z_tensors;
     std::vector<Tensor<LSEDataType>> lse_tensors;

     using DeviceMemPtr = std::unique_ptr<DeviceMem>;

@@ -63,6 +64,7 @@ int run(int argc, char* argv[])
     std::vector<DeviceMemPtr> b0_tensors_device;
     std::vector<DeviceMemPtr> b1_tensors_device;
     std::vector<DeviceMemPtr> c_tensors_device;
+    std::vector<DeviceMemPtr> z_tensors_device;
     std::vector<DeviceMemPtr> lse_tensors_device;

     std::size_t flop = 0, num_byte = 0;

@@ -103,8 +105,8 @@ int run(int argc, char* argv[])
                 ? std::vector<ck::index_t>{M * G1 * O, O, G1 * O, 1} // C layout [G0, M, G1, O]
                 : std::vector<ck::index_t>{G1 * M * O, M * O, O, 1}; // C layout [G0, G1, M, O]

-        std::vector<ck::index_t> z_gs_ms_os_lengths{G0, G1, M, N};
-        std::vector<ck::index_t> z_gs_ms_os_strides =
+        std::vector<ck::index_t> z_gs_ms_ns_lengths{G0, G1, M, N};
+        std::vector<ck::index_t> z_gs_ms_ns_strides =
             output_permute
                 ? std::vector<ck::index_t>{M * G1 * N, N, G1 * N, 1} // Z layout [G0, M, G1, N]
                 : std::vector<ck::index_t>{G1 * M * N, M * N, N, 1}; // Z layout [G0, G1, M, N]

@@ -134,8 +136,8 @@ int run(int argc, char* argv[])
         Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
         Tensor<B0DataType> b0_gs_ns_ks(b0_gs_ns_ks_lengths, b0_gs_ns_ks_strides);
         Tensor<B1DataType> b1_gs_os_ns(b1_gs_os_ns_lengths, b1_gs_os_ns_strides);
-        Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
         Tensor<CDataType> c_gs_ms_os_device_result(c_gs_ms_os_lengths, c_gs_ms_os_strides);
+        Tensor<ZDataType> z_gs_ms_ns(z_gs_ms_ns_lengths, z_gs_ms_ns_strides);
         Tensor<LSEDataType> lse_gs_ms_device_result(lse_gs_ms_lengths, lse_gs_ms_strides);

         int Batch = G0 * G1;

@@ -150,6 +152,7 @@ int run(int argc, char* argv[])
                       << "b0_gs_ns_ks[" << i << "]: " << b0_gs_ns_ks.mDesc << ", "
                       << "b1_gs_os_ns[" << i << "]: " << b1_gs_os_ns.mDesc << ", "
                       << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
+                      << "c_gs_ms_os[" << i << "]: " << c_gs_ms_os_device_result.mDesc << ", "
                       << "lse_gs_ms_os[" << i << "]: " << lse_gs_ms_device_result.mDesc
                       << std::endl;
         }

@@ -182,6 +185,7 @@ int run(int argc, char* argv[])
         b0_tensors.push_back(b0_gs_ns_ks);
         b1_tensors.push_back(b1_gs_os_ns);
         c_tensors.push_back(c_gs_ms_os_device_result);
+        z_tensors.push_back(c_gs_ms_os_device_result);
         lse_tensors.push_back(lse_gs_ms_device_result);

         a_tensors_device.emplace_back(std::make_unique<DeviceMem>(
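Note on the renamed stride vectors: the Z tensor always has logical lengths {G0, G1, M, N}, and the two stride sets in the hunk at line 103 are simply the row-major strides of the permuted [G0, M, G1, N] and the plain [G0, G1, M, N] memory layouts, following the same convention as the C strides just above them. A small self-contained sketch of that computation (helper name and test sizes are made up for illustration):

#include <array>
#include <cassert>
#include <cstdint>

using index_t = int32_t; // stand-in for ck::index_t

// Strides for a Z tensor whose logical dims are (G0, G1, M, N), matching the
// two layouts in the example: [G0, M, G1, N] when output_permute is set,
// [G0, G1, M, N] otherwise. Mirrors the ternary in the diff above.
std::array<index_t, 4> z_strides(index_t G1, index_t M, index_t N, bool output_permute)
{
    if(output_permute)
        return {M * G1 * N, N, G1 * N, 1}; // Z layout [G0, M, G1, N]
    return {G1 * M * N, M * N, N, 1};      // Z layout [G0, G1, M, N]
}

int main()
{
    const index_t G1 = 12, M = 256, N = 128;

    // Permuted layout: G1 varies faster than M in memory.
    auto s = z_strides(G1, M, N, /*output_permute=*/true);
    assert(s[0] == M * G1 * N && s[1] == N && s[2] == G1 * N && s[3] == 1);

    // Non-permuted layout is plain row-major over (G0, G1, M, N).
    s = z_strides(G1, M, N, /*output_permute=*/false);
    assert(s[0] == G1 * M * N && s[1] == M * N && s[2] == N && s[3] == 1);
    return 0;
}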
include/ck/tensor_operation/gpu/device/device_grouped_gemm_softmax_gemm_permute.hpp

@@ -105,8 +105,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermuteTrain : public BaseOperator
         std::vector<index_t> c_gs_ms_os_lengths;
         std::vector<index_t> c_gs_ms_os_strides;

-        std::vector<index_t> z_gs_ms_os_lengths;
-        std::vector<index_t> z_gs_ms_os_strides;
+        std::vector<index_t> z_gs_ms_ns_lengths;
+        std::vector<index_t> z_gs_ms_ns_strides;

         std::vector<index_t> lse_gs_ms_lengths;
         std::vector<index_t> lse_gs_ms_strides;
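Note: the z_gs_ms_os_* → z_gs_ms_ns_* rename matches the shapes used in the example above: Z spans the first GEMM's M×N softmax matrix per (G0, G1) batch, while C spans M×O, so the "ns" suffix is the accurate one. An illustrative fragment (not the real CK ProblemDesc declaration) showing the renamed fields next to the C and LSE ones:

#include <cstdint>
#include <vector>

using index_t = int32_t; // stand-in for ck::index_t

// Illustrative only: the subset of problem-descriptor fields touched by this
// hunk, after the z_gs_ms_os_* -> z_gs_ms_ns_* rename.
struct ProblemDescZSlice
{
    std::vector<index_t> c_gs_ms_os_lengths;  // C is (G0, G1, M, O)
    std::vector<index_t> c_gs_ms_os_strides;
    std::vector<index_t> z_gs_ms_ns_lengths;  // Z is (G0, G1, M, N) -- renamed
    std::vector<index_t> z_gs_ms_ns_strides;
    std::vector<index_t> lse_gs_ms_lengths;   // LSE is (G0, G1, M)
    std::vector<index_t> lse_gs_ms_strides;
};

int main()
{
    const index_t G0 = 4, G1 = 12, M = 256, N = 128, O = 64;
    ProblemDescZSlice pd;
    pd.c_gs_ms_os_lengths = {G0, G1, M, O};
    pd.z_gs_ms_ns_lengths = {G0, G1, M, N}; // N, not O: Z covers the softmax matrix
    pd.lse_gs_ms_lengths  = {G0, G1, M};
    (void)pd;
    return 0;
}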
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_softmax_gemm_permute_train_xdl_cshuffle.hpp

@@ -97,12 +97,17 @@ __global__ void
     const long_index_t lse_batch_offset = __builtin_amdgcn_readfirstlane(static_cast<long_index_t>(
         arg_ptr[group_id].compute_base_ptr_of_batch_.GetLSEBasePtr(g_idx)));

+    //unsigned short* p_z_grid_in = //
+    //    (arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
+    //                                            : arg_ptr[group_id].p_z_grid_ + z_batch_offset);
+
     GridwiseGemm::template Run<HasMainKBlockLoop, IsDropout>(
         arg_ptr[group_id].p_a_grid_ + a_batch_offset,
         arg_ptr[group_id].p_b_grid_ + b_batch_offset,
         arg_ptr[group_id].p_b1_grid_ + b1_batch_offset,
         arg_ptr[group_id].p_c_grid_ + c_batch_offset,
-        arg_ptr[group_id].p_z_grid_ + z_batch_offset,
+        arg_ptr[group_id].p_z_grid_ == nullptr ? nullptr
+                                               : arg_ptr[group_id].p_z_grid_ + z_batch_offset,
         arg_ptr[group_id].p_lse_grid_ + lse_batch_offset,
         p_shared,
         a_element_op,

@@ -114,7 +119,7 @@ __global__ void
         arg_ptr[group_id].b_grid_desc_bk0_n_bk1_,
         arg_ptr[group_id].b1_grid_desc_bk0_n_bk1_,
         arg_ptr[group_id].c_grid_desc_mblock_mperblock_nblock_nperblock_,
-        arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,  ////////
+        arg_ptr[group_id].z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_, ////////
         arg_ptr[group_id].lse_grid_desc_m_,
         arg_ptr[group_id].block_2_ctile_map_,
         arg_ptr[group_id].c0_matrix_mask_,

@@ -411,7 +416,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         BGridDesc_G_N_K b_grid_desc_g_n_k_;
         B1GridDesc_G_N_K b1_grid_desc_g_n_k_;
         CGridDesc_G_M_N c_grid_desc_g_m_n_;
-        ZGridDesc_G_M_N& z_grid_desc_g_m_n_;
+        ZGridDesc_G_M_N z_grid_desc_g_m_n_;
         index_t BatchStrideLSE_;
     };

@@ -490,7 +495,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         const BDataType* p_b_grid_;
         const B1DataType* p_b1_grid_;
         CDataType* p_c_grid_;
-        ZDataType* p_z_grid;
+        ZDataType* p_z_grid_;
         LSEDataType* p_lse_grid_;

         // tensor descriptors for block/thread-wise copy

@@ -499,6 +504,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
         B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
+        typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+            z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5_;
         ZGridDesc_M_N z_grid_desc_m_n_;
         LSEGridDesc_M lse_grid_desc_m_;

@@ -592,7 +599,7 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
             const auto c_grid_desc_m_n = Transform::MakeCGridDescriptor_M_N(
                 problem_desc.c_gs_ms_os_lengths, problem_desc.c_gs_ms_os_strides);
             const auto z_grid_desc_m_n = MakeZGridDescriptor_M_N(
-                problem_desc.z_gs_ms_os_lengths, problem_desc.z_gs_ms_os_strides);
+                problem_desc.z_gs_ms_ns_lengths, problem_desc.z_gs_ms_ns_strides);
             const auto lse_grid_desc_m =
                 DeviceOp::MakeLSEGridDescriptor_M(problem_desc.lse_gs_ms_lengths[NumDimG]);

@@ -610,6 +617,13 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
             const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
                 GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                     c_grid_desc_m_n);

+            //typename GridwiseGemm::ZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5
+            //    z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
+            auto z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
+                GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
+                    z_grid_desc_m_n);
+
             const index_t BlockStart     = grid_size_;
             const auto block_2_ctile_map = Block2CTileMap(c_grid_desc_m_n, BlockStart);

@@ -654,7 +668,8 @@ struct DeviceGroupedGemmSoftmaxGemmPermute_Train_Xdl_CShuffle
                 b_grid_desc_bk0_n_bk1,
                 b1_grid_desc_bk0_n_bk1,
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
-                z_grid_desc_g_m_n,
+                z_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
+                z_grid_desc_m_n,
                 lse_grid_desc_m,
                 block_2_ctile_map.CalculateGridSize(c_grid_desc_m_n),
                 compute_base_ptr_of_batch,
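Note on the kernel-entry change: p_z_grid_ can legitimately be nullptr when no dropout mask output is requested, and the old code unconditionally added z_batch_offset to it, which is undefined behaviour and also turns the null into a non-null garbage address, defeating any nullptr check downstream. The new code applies the offset only when the pointer is present. A minimal sketch of the same guard in isolation (helper name is hypothetical):

#include <cstddef>

// Apply a per-batch element offset only when the base pointer is present.
// Mirrors the guarded form used for p_z_grid_ in the diff above.
template <typename T>
T* offset_if_present(T* base, std::ptrdiff_t batch_offset)
{
    return base == nullptr ? nullptr : base + batch_offset;
}

int main()
{
    unsigned short* p_z = nullptr;          // dropout output not requested
    auto* p = offset_if_present(p_z, 4096); // stays nullptr instead of becoming 0x2000
    return p == nullptr ? 0 : 1;
}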
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v2.hpp

@@ -98,6 +98,8 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
     static constexpr auto I6 = Number<6>{};
     static constexpr auto I7 = Number<7>{};

+    static constexpr auto WaveSize = 64;
+
     // K1 should be Number<...>
     // Gemm0
     static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};

@@ -124,7 +126,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
         const auto M = z_grid_desc_m_n.GetLength(I0);
         const auto N = z_grid_desc_m_n.GetLength(I1);

-        constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
         constexpr auto N3   = mfma.num_groups_per_blk;
         constexpr auto N4   = mfma.num_input_blks;
         constexpr auto N5   = mfma.group_size;

@@ -140,7 +142,7 @@ struct GridwiseBatchedGemmSoftmaxGemmTrain_Xdl_CShuffle
     __host__ __device__ static constexpr auto
     MakeZGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(const index_t M, const index_t N) ////=> for z use
     {
-        constexpr auto mfma = MfmaSelector<DataType, MPerXdl, NPerXdl>::selected_mfma;
+        constexpr auto mfma = MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma;
         constexpr auto N3   = mfma.num_groups_per_blk;
         constexpr auto N4   = mfma.num_input_blks;
         constexpr auto N5   = mfma.group_size;
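Note: switching MfmaSelector from DataType to FloatAB keys the MFMA instruction choice on the gridwise kernel's actual A/B element type parameter, which is presumably what let this file compile (per the commit message). The new WaveSize = 64 constant matches the 64-lane wavefront of the XDL-capable AMD GPUs this kernel targets; a typical use is deriving the wave count from the workgroup size, sketched below with an illustrative BlockSize rather than a value taken from this kernel:

#include <cstdio>

// Compile-time wave bookkeeping of the kind a gridwise kernel does once a
// WaveSize constant is available. BlockSize here is an example value.
constexpr int WaveSize  = 64;  // lanes per wavefront on gfx9 XDL hardware
constexpr int BlockSize = 256; // threads per workgroup (illustrative)

constexpr int WavesPerBlock = BlockSize / WaveSize;
static_assert(BlockSize % WaveSize == 0, "block size must be a whole number of waves");

int main()
{
    std::printf("waves per block: %d\n", WavesPerBlock); // prints 4
    return 0;
}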