gaoqiong / composable_kernel

Commit 0d2aafb2 (unverified), authored Aug 23, 2022 by Rostyslav Geyyer, committed by GitHub on Aug 23, 2022

Merge branch 'develop' into lwpck-359_int4

Parents: bd78cb4b, e0d8806c
Showing 11 changed files with 430 additions and 154 deletions (+430 −154):
include/ck/tensor_operation/gpu/device/matrix_padder.hpp  +156 −137
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp  +5 −1
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp  +9 −3
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp  +0 −2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp  +4 −9
include/ck/utility/functional.hpp  +14 −0
library/include/ck/library/utility/host_tensor.hpp  +2 −1
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp  +4 −1
profiler/include/profile_batched_gemm_gemm_impl.hpp  +6 −0
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp  +109 −0
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp  +121 −0
include/ck/tensor_operation/gpu/device/matrix_padder.hpp (+156 −137)

@@ -12,166 +12,176 @@ namespace ck {

The hunk adds a generic `PadTensorDescriptor` helper that, per dimension, selects either a right-pad transform or a pass-through transform at compile time. First, the overload for tensors without a batch dimension:

```cpp
namespace tensor_operation {
namespace device {

// For padding tensors without batch dimension
template <bool PadM,
          bool PadN,
          typename TensorDesc_MRaw_NRaw,
          typename MPerBlockType,
          typename NPerBlockType,
          enable_if_t<TensorDesc_MRaw_NRaw::GetNumOfVisibleDimension() == 2, bool> = false>
__host__ __device__ constexpr auto
PadTensorDescriptor(const TensorDesc_MRaw_NRaw& tensor_desc_mraw_nraw,
                    MPerBlockType MPerBlock,
                    NPerBlockType NPerBlock)
{
    const auto MRaw = tensor_desc_mraw_nraw.GetLength(Number<0>{});
    const auto NRaw = tensor_desc_mraw_nraw.GetLength(Number<1>{});

    const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
    const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;

    const auto MPad = M - MRaw;
    const auto NPad = N - NRaw;

    const auto MTransform = conditional_expr<PadM>(make_right_pad_transform(MRaw, MPad),
                                                   make_pass_through_transform(MRaw));
    const auto NTransform = conditional_expr<PadN>(make_right_pad_transform(NRaw, NPad),
                                                   make_pass_through_transform(NRaw));

    return transform_tensor_descriptor(tensor_desc_mraw_nraw,
                                       make_tuple(MTransform, NTransform),
                                       make_tuple(Sequence<0>{}, Sequence<1>{}),
                                       make_tuple(Sequence<0>{}, Sequence<1>{}));
}
```
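To see the sizing arithmetic concretely, here is a minimal, self-contained sketch (plain C++, no CK dependencies; `integer_divide_ceil` is reimplemented here purely for illustration):

```cpp
#include <cassert>

// Same ceil-divide rounding PadTensorDescriptor uses to size the padded extent.
constexpr int integer_divide_ceil(int x, int y) { return (x + y - 1) / y; }

int main()
{
    const int MRaw = 136, MPerBlock = 128;
    const int M    = integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; // rounds up to 256
    const int MPad = M - MRaw;                                         // 120 padding rows
    assert(M == 256 && MPad == 120);

    // An already-aligned extent needs no padding:
    const int NRaw = 128, NPerBlock = 128;
    assert(integer_divide_ceil(NRaw, NPerBlock) * NPerBlock - NRaw == 0);
    return 0;
}
```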
The overload for tensors with a batch dimension is identical except that the leading G dimension is passed through unchanged:

```cpp
// For padding tensors with batch dimension
template <bool PadM,
          bool PadN,
          typename TensorDesc_GRaw_MRaw_NRaw,
          typename MPerBlockType,
          typename NPerBlockType,
          enable_if_t<TensorDesc_GRaw_MRaw_NRaw::GetNumOfVisibleDimension() == 3, bool> = false>
__host__ __device__ constexpr auto
PadTensorDescriptor(const TensorDesc_GRaw_MRaw_NRaw& tensor_desc_graw_mraw_nraw,
                    MPerBlockType MPerBlock,
                    NPerBlockType NPerBlock)
{
    const auto GRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<0>{});
    const auto MRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<1>{});
    const auto NRaw = tensor_desc_graw_mraw_nraw.GetLength(Number<2>{});

    const auto M = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
    const auto N = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;

    const auto MPad = M - MRaw;
    const auto NPad = N - NRaw;

    const auto MTransform = conditional_expr<PadM>(make_right_pad_transform(MRaw, MPad),
                                                   make_pass_through_transform(MRaw));
    const auto NTransform = conditional_expr<PadN>(make_right_pad_transform(NRaw, NPad),
                                                   make_pass_through_transform(NRaw));

    return transform_tensor_descriptor(
        tensor_desc_graw_mraw_nraw,
        make_tuple(make_pass_through_transform(GRaw), MTransform, NTransform),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
}
```
A new `GemmGemmPadder` covers the fused GEMM+GEMM case, which pads up to four dimensions (M, N, K, O):

```cpp
// M/N/K/OPerTileType could be index_t or Number<>
template <GemmSpecialization GemmSpec,
          typename MPerTileType,
          typename NPerTileType,
          typename KPerTileType,
          typename OPerTileType>
struct GemmGemmPadder
{
    // TODO: hard to scale; use mask instead
    static constexpr bool PadM =
        GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
        GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
        GemmSpec == GemmSpecialization::MOPadding || GemmSpec == GemmSpecialization::MNOPadding ||
        GemmSpec == GemmSpecialization::MKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
    static constexpr bool PadN =
        GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
        GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
        GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::MNOPadding ||
        GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
    static constexpr bool PadK =
        GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
        GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding ||
        GemmSpec == GemmSpecialization::KOPadding || GemmSpec == GemmSpecialization::MKOPadding ||
        GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;
    static constexpr bool PadO =
        GemmSpec == GemmSpecialization::OPadding || GemmSpec == GemmSpecialization::MOPadding ||
        GemmSpec == GemmSpecialization::NOPadding || GemmSpec == GemmSpecialization::KOPadding ||
        GemmSpec == GemmSpecialization::MNOPadding || GemmSpec == GemmSpecialization::MKOPadding ||
        GemmSpec == GemmSpecialization::NKOPadding || GemmSpec == GemmSpecialization::MNKOPadding;

    // A[M, K]
    template <typename ADesc_MRaw_KRaw>
    __host__ __device__ constexpr auto
    PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
    {
        return PadTensorDescriptor<PadM, PadK>(a_desc_mraw_kraw, MPerTile_, KPerTile_);
    }

    // B[K, N]
    template <typename BDesc_NRaw_KRaw>
    __host__ __device__ constexpr auto
    PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
    {
        return PadTensorDescriptor<PadN, PadK>(b_desc_nraw_kraw, NPerTile_, KPerTile_);
    }

    // B1[Gemm1N, Gemm1K] = B1[O, N]
    template <typename B1Desc_NRaw_KRaw>
    __host__ __device__ constexpr auto
    PadB1Descriptor_N_K(const B1Desc_NRaw_KRaw& b1_desc_nraw_kraw) const
    {
        return PadTensorDescriptor<PadO, PadN>(b1_desc_nraw_kraw, OPerTile_, NPerTile_);
    }

    // C[M, Gemm1N] = C[M, O]
    template <typename CDesc_MRaw_NRaw>
    __host__ __device__ constexpr auto
    PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
    {
        return PadTensorDescriptor<PadM, PadO>(c_desc_mraw_nraw, MPerTile_, OPerTile_);
    }

    MPerTileType MPerTile_;
    NPerTileType NPerTile_;
    KPerTileType KPerTile_;
    OPerTileType OPerTile_;
};
```
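The `// TODO: hard to scale; use mask instead` comment names the obvious growth problem: each flag above must enumerate eight of the sixteen four-dimension `GemmSpecialization` values, and a fifth paddable dimension would double that again. As a hedged sketch of the mask idea the TODO suggests (the enum, its encoding, and all names below are hypothetical, not CK's):

```cpp
#include <cstdint>

// Hypothetical bitmask encoding: one bit per paddable dimension.
enum class PadMask : std::uint8_t
{
    None = 0,
    M    = 1 << 0,
    N    = 1 << 1,
    K    = 1 << 2,
    O    = 1 << 3,
};

constexpr PadMask operator|(PadMask a, PadMask b)
{
    return static_cast<PadMask>(static_cast<std::uint8_t>(a) | static_cast<std::uint8_t>(b));
}

template <PadMask Mask>
struct GemmGemmPadderSketch
{
    // Each flag becomes a single bit test instead of an eight-way '==' chain.
    static constexpr bool PadM = (static_cast<std::uint8_t>(Mask) & static_cast<std::uint8_t>(PadMask::M)) != 0;
    static constexpr bool PadN = (static_cast<std::uint8_t>(Mask) & static_cast<std::uint8_t>(PadMask::N)) != 0;
    static constexpr bool PadK = (static_cast<std::uint8_t>(Mask) & static_cast<std::uint8_t>(PadMask::K)) != 0;
    static constexpr bool PadO = (static_cast<std::uint8_t>(Mask) & static_cast<std::uint8_t>(PadMask::O)) != 0;
};

static_assert(GemmGemmPadderSketch<PadMask::M | PadMask::K>::PadM);
static_assert(!GemmGemmPadderSketch<PadMask::M | PadMask::K>::PadN);

int main() { return 0; }
```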
The pre-existing `MatrixPadder` is renamed to `GemmPadder`, and its hand-written `if constexpr` specialization chains are replaced by calls to `PadTensorDescriptor`; a `MatrixPadder` alias is kept for backward compatibility:

```diff
 // M/N/KPerTileType could be index_t or Number<>
 template <GemmSpecialization GemmSpec,
           typename MPerTileType,
           typename NPerTileType,
           typename KPerTileType>
-struct MatrixPadder
+struct GemmPadder
 {
-    static constexpr auto I0 = Number<0>{};
-    static constexpr auto I1 = Number<1>{};
-    static constexpr auto I2 = Number<2>{};
-    static constexpr auto I3 = Number<3>{};
-
     static constexpr bool PadM =
         (GemmSpec == GemmSpecialization::MPadding || GemmSpec == GemmSpecialization::MNPadding ||
          GemmSpec == GemmSpecialization::MKPadding || GemmSpec == GemmSpecialization::MNKPadding);
     static constexpr bool PadN =
         (GemmSpec == GemmSpecialization::NPadding || GemmSpec == GemmSpecialization::MNPadding ||
          GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
     static constexpr bool PadK =
         (GemmSpec == GemmSpecialization::KPadding || GemmSpec == GemmSpecialization::MKPadding ||
          GemmSpec == GemmSpecialization::NKPadding || GemmSpec == GemmSpecialization::MNKPadding);
 
     template <typename ADesc_MRaw_KRaw>
     __host__ __device__ constexpr auto
     PadADescriptor_M_K(const ADesc_MRaw_KRaw& a_desc_mraw_kraw) const
     {
-        const auto MRaw = a_desc_mraw_kraw.GetLength(I0);
-        const auto KRaw = a_desc_mraw_kraw.GetLength(I1);
-
-        const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
-        const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
-
-        const auto MPad = M - MRaw;
-        const auto KPad = K - KRaw;
-
-        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            // pad both M and K
-            return transform_tensor_descriptor(a_desc_mraw_kraw,
-                                               make_tuple(make_right_pad_transform(MRaw, MPad),
-                                                          make_right_pad_transform(KRaw, KPad)),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                          GemmSpec == GemmSpecialization::MNPadding)
-        {
-            // pad M, but not K
-            return transform_tensor_descriptor(a_desc_mraw_kraw,
-                                               make_tuple(make_right_pad_transform(MRaw, MPad),
-                                                          make_pass_through_transform(KRaw)),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
-                          GemmSpec == GemmSpecialization::NKPadding)
-        {
-            // pad K, but not M
-            return transform_tensor_descriptor(a_desc_mraw_kraw,
-                                               make_tuple(make_pass_through_transform(MRaw),
-                                                          make_right_pad_transform(KRaw, KPad)),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            // not pad M or K
-            return a_desc_mraw_kraw;
-        }
+        return PadTensorDescriptor<PadM, PadK>(a_desc_mraw_kraw, MPerTile_, KPerTile_);
     }
 
     template <typename BDesc_NRaw_KRaw>
     __host__ __device__ constexpr auto
     PadBDescriptor_N_K(const BDesc_NRaw_KRaw& b_desc_nraw_kraw) const
     {
-        const auto NRaw = b_desc_nraw_kraw.GetLength(I0);
-        const auto KRaw = b_desc_nraw_kraw.GetLength(I1);
-
-        const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
-        const auto K = math::integer_divide_ceil(KRaw, KPerTile_) * KPerTile_;
-
-        const auto NPad = N - NRaw;
-        const auto KPad = K - KRaw;
-
-        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            // pad both N and K
-            return transform_tensor_descriptor(b_desc_nraw_kraw,
-                                               make_tuple(make_right_pad_transform(NRaw, NPad),
-                                                          make_right_pad_transform(KRaw, KPad)),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
-                          GemmSpec == GemmSpecialization::MNPadding)
-        {
-            // pad N, but not K
-            return transform_tensor_descriptor(b_desc_nraw_kraw,
-                                               make_tuple(make_right_pad_transform(NRaw, NPad),
-                                                          make_pass_through_transform(KRaw)),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
-                          GemmSpec == GemmSpecialization::MKPadding)
-        {
-            // pad K, but not N
-            return transform_tensor_descriptor(b_desc_nraw_kraw,
-                                               make_tuple(make_pass_through_transform(NRaw),
-                                                          make_right_pad_transform(KRaw, KPad)),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            // not pad N or K
-            return b_desc_nraw_kraw;
-        }
+        return PadTensorDescriptor<PadN, PadK>(b_desc_nraw_kraw, NPerTile_, KPerTile_);
     }
 
     template <typename CDesc_MRaw_NRaw>
     __host__ __device__ constexpr auto
     PadCDescriptor_M_N(const CDesc_MRaw_NRaw& c_desc_mraw_nraw) const
     {
-        const auto MRaw = c_desc_mraw_nraw.GetLength(I0);
-        const auto NRaw = c_desc_mraw_nraw.GetLength(I1);
-
-        const auto M = math::integer_divide_ceil(MRaw, MPerTile_) * MPerTile_;
-        const auto N = math::integer_divide_ceil(NRaw, NPerTile_) * NPerTile_;
-
-        const auto MPad = M - MRaw;
-        const auto NPad = N - NRaw;
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
-        {
-            // pad M and N
-            return transform_tensor_descriptor(c_desc_mraw_nraw,
-                                               make_tuple(make_right_pad_transform(MRaw, MPad),
-                                                          make_right_pad_transform(NRaw, NPad)),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                          GemmSpec == GemmSpecialization::MKPadding)
-        {
-            // pad M, but not N
-            return transform_tensor_descriptor(c_desc_mraw_nraw,
-                                               make_tuple(make_right_pad_transform(MRaw, MPad),
-                                                          make_pass_through_transform(NRaw)),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
-                          GemmSpec == GemmSpecialization::NKPadding)
-        {
-            // pad N, but not M
-            return transform_tensor_descriptor(c_desc_mraw_nraw,
-                                               make_tuple(make_pass_through_transform(MRaw),
-                                                          make_right_pad_transform(NRaw, NPad)),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
-                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            // not pad M or N
-            return c_desc_mraw_nraw;
-        }
+        return PadTensorDescriptor<PadM, PadN>(c_desc_mraw_nraw, MPerTile_, NPerTile_);
     }
 
     MPerTileType MPerTile_;
```

@@ -179,6 +189,15 @@ struct MatrixPadder

```diff
     KPerTileType KPerTile_;
 };
 
+// Alias of GemmPadder; to deprecate
+template <GemmSpecialization GemmSpec,
+          typename MPerTileType,
+          typename NPerTileType,
+          typename KPerTileType>
+struct MatrixPadder
+    : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KPerTileType>
+{
+};
+
 } // namespace device
 } // namespace tensor_operation
 } // namespace ck
```
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp (+5 −1)

@@ -189,7 +189,11 @@ struct AddAddFastGelu

The experimental int4 element type joins the parameter whitelist behind its build flag:

```diff
     template <typename T>
     static inline constexpr bool is_valid_param_type_v =
         std::is_same_v<T, float> || std::is_same_v<T, half_t> || std::is_same_v<T, bhalf_t> ||
-        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>;
+        std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>
+#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
+        || std::is_same_v<T, ck::int4_t>
+#endif
+        ;
 
     template <typename E, typename C, typename D0, typename D1>
     __host__ __device__ constexpr void
```
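Since `is_valid_param_type_v` is just a compile-time whitelist, the gating pattern is easy to reproduce in isolation. A self-contained sketch (the `int4_t` stand-in below is illustrative only; CK's real type lives behind the same macro):

```cpp
#include <cstdint>
#include <type_traits>

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
using int4_t = signed char; // stand-in for illustration; CK defines its own int4 type
#endif

// Compile-time whitelist of element types, optionally extended by a build flag.
template <typename T>
inline constexpr bool is_valid_param_type_v =
    std::is_same_v<T, float> || std::is_same_v<T, std::int32_t> || std::is_same_v<T, std::int8_t>
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
    || std::is_same_v<T, int4_t>
#endif
    ;

static_assert(is_valid_param_type_v<float>);
static_assert(!is_valid_param_type_v<double>);

int main() { return 0; }
```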
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp (+9 −3)

@@ -200,7 +200,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle

`CheckValidity` now also receives the raw problem sizes:

```diff
                   const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                   const B1GridDesc_BK0_N_BK1& b1_grid_desc_bk0_n_bk1,
                   const CGridDesc_M_N& c_grid_desc_m_n,
-                  const Block2CTileMap& block_2_ctile_map)
+                  const Block2CTileMap& block_2_ctile_map,
+                  const std::vector<index_t>& lengths_m_n_k_o)
     {
         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
```

@@ -216,6 +217,13 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle

```diff
             return false;
         }
 
+        // K is rounded to nearest multiples of K1 during tensor transformation so instead get KRaw
+        const auto KRaw = lengths_m_n_k_o[2];
+        if(!(KRaw % AK1 == 0 && KRaw % BK1 == 0))
+        {
+            return false;
+        }
+
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0 &&
              Gemm1N % Gemm1NPerBlock == 0))
         {
```

@@ -241,8 +249,6 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle

```diff
             return false;
         }
 
-        assert(num_gemm1_k_outer_loop * num_gemm1_k_inner_loop == N / Gemm1KPerBlock);
-
         if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
         {
             return false;
```
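The new comment explains the subtlety: by the time `CheckValidity` sees the transformed K, the descriptor transform has already rounded it up to a multiple of K1, so a `K % AK1` test could never fail; the raw extent is what must be checked. A tiny standalone illustration (the AK1 value is assumed for the example):

```cpp
#include <cstdio>

int main()
{
    const int KRaw = 129, AK1 = 8;
    const int K    = (KRaw + AK1 - 1) / AK1 * AK1; // what the transform produces: 136

    std::printf("K %% AK1    = %d (always 0 after rounding)\n", K % AK1);
    std::printf("KRaw %% AK1 = %d (the check that actually rejects K = 129)\n", KRaw % AK1);
    return 0;
}
```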
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+0 −2)

@@ -245,8 +245,6 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

The same runtime assert is dropped here as in the non-softmax variant:

```diff
             return false;
         }
 
-        assert(num_gemm1_k_outer_loop * num_gemm1_k_inner_loop == N / Gemm1KPerBlock);
-
         if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
         {
             return false;
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp (+4 −9)

@@ -53,7 +53,7 @@ __global__ void

The shared-memory scratch block is now passed to `Run` as an untyped pointer:

```diff
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
-                                                  p_shared_block,
+                                                  static_cast<void*>(p_shared_block),
                                                   a_b_k0_m_k1_grid_desc,
                                                   b_b_k0_n_k1_grid_desc,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
```

@@ -270,7 +270,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2

```diff
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
-                               FloatAB* __restrict__ p_shared_block,
+                               void* __restrict__ p_shared_block,
                                const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                                const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                                const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
```

@@ -463,8 +463,8 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2

```diff
         constexpr auto a_block_space_size =
             math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);
 
-        FloatAB* p_a_block = p_shared_block;
-        FloatAB* p_b_block = p_shared_block + a_block_space_size;
+        FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block);
+        FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;
 
         constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
```

@@ -547,11 +547,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2

```diff
                 static_cast<FloatC*>(p_shared_block),
                 c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
-            static_assert(M1 == MWave, "");
-            static_assert(N1 == NWave, "");
-            static_assert(M2 * M3 * M4 == MPerXDL, "");
-            static_assert(N2 == NPerXDL, "");
-
             constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                 c_block_desc_mblock_mperblock_nblock_nperblock,
                 make_tuple(
```
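The motivation for the `void*` signature is visible in the last hunk's context: the same LDS block is reused as `FloatAB` storage for the A/B tiles and later reinterpreted as `FloatC` storage for the shuffle stage, so an untyped parameter with per-use casts is the more honest interface. A minimal host-side sketch of the partitioning (types and sizes here are illustrative, not the kernel's):

```cpp
#include <cstddef>

using FloatAB = float; // illustrative; the kernel's FloatAB is a template parameter

// Carve the A and B tile buffers out of one untyped scratch block,
// mirroring the casts the kernel now performs on its LDS pointer.
void partition_lds(void* p_shared_block,
                   std::size_t a_block_space_size,
                   FloatAB*& p_a_block,
                   FloatAB*& p_b_block)
{
    p_a_block = static_cast<FloatAB*>(p_shared_block);
    p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;
}

int main()
{
    alignas(16) unsigned char lds[1024];
    FloatAB *a = nullptr, *b = nullptr;
    partition_lds(lds, 64, a, b);
    return (b - a == 64) ? 0 : 1; // B tile starts 64 elements past A
}
```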
include/ck/utility/functional.hpp (+14 −0)

@@ -114,4 +114,18 @@ struct conditional<false, X, Y>

A compile-time value selector, `conditional_expr`, is added alongside the existing type-level `conditional_t`; it is what `PadTensorDescriptor` uses to choose between transforms:

```diff
 template <bool predicate, class X, class Y>
 using conditional_t = typename conditional<predicate, X, Y>::type;
 
+// z = predicate ? x : y
+template <bool predicate, typename X, typename Y>
+constexpr auto conditional_expr(X&& x, Y&& y)
+{
+    if constexpr(predicate)
+    {
+        return std::forward<X>(x);
+    }
+    else
+    {
+        return std::forward<Y>(y);
+    }
+}
+
 } // namespace ck
```
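Unlike a runtime ternary, `conditional_expr` may return values of different types, because only the taken branch participates in return-type deduction. A self-contained demonstration (the function is re-declared here so the example compiles without CK):

```cpp
#include <string>
#include <type_traits>
#include <utility>

template <bool predicate, typename X, typename Y>
constexpr auto conditional_expr(X&& x, Y&& y)
{
    if constexpr(predicate) { return std::forward<X>(x); }
    else                    { return std::forward<Y>(y); }
}

int main()
{
    // The two branches have different types; a ternary could not do this.
    auto a = conditional_expr<true>(42, std::string{"fallback"});
    auto b = conditional_expr<false>(42, std::string{"fallback"});

    static_assert(std::is_same_v<decltype(a), int>);
    static_assert(std::is_same_v<decltype(b), std::string>);

    return (a == 42 && b == "fallback") ? 0 : 1;
}
```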
library/include/ck/library/utility/host_tensor.hpp (+2 −1)

@@ -271,7 +271,8 @@ struct Tensor

An explicit element-type-converting constructor is added, implemented via the existing `CopyAsType` (the diff view repeated the move-assignment line; it is deduplicated here):

```diff
     ~Tensor() = default;
 
     Tensor& operator=(const Tensor&) = default;
     Tensor& operator=(Tensor&&) = default;
 
+    template <typename FromT>
+    explicit Tensor(const Tensor<FromT>& other) : Tensor(other.template CopyAsType<T>())
```
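A self-contained analog of the pattern (assumption: this is a deliberately simplified stand-in; the real `ck::Tensor` carries descriptors, strides, and more):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

template <typename T>
struct Tensor
{
    std::vector<T> mData;

    Tensor() = default;
    explicit Tensor(std::size_t n) : mData(n) {}

    // Element-wise copy into a tensor of another element type.
    template <typename U>
    Tensor<U> CopyAsType() const
    {
        Tensor<U> out(mData.size());
        std::transform(mData.begin(), mData.end(), out.mData.begin(),
                       [](T v) { return static_cast<U>(v); });
        return out;
    }

    // Converting constructor, mirroring the one added in this commit:
    // delegates to the move constructor with the converted copy.
    template <typename FromT>
    explicit Tensor(const Tensor<FromT>& other) : Tensor(other.template CopyAsType<T>())
    {
    }
};

int main()
{
    Tensor<float> tf(4);
    Tensor<int> ti(tf); // converts each element via CopyAsType<int>()
    return ti.mData.size() == 4 ? 0 : 1;
}
```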
library/src/tensor_operation_instance/gpu/batched_gemm_gemm/device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp (+4 −1)

@@ -26,6 +26,7 @@ using S = ck::Sequence<Is...>;

```diff
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
 
 static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmPadded  = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;
 
 // c[g, m, n] = a[g, m, k] * b[g, n, k]
 using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances = std::tuple<
```

@@ -37,7 +38,9 @@ using device_batched_gemm_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances

The last `GemmDefault` instance gains a trailing comma and an `MNKOPadding` fallback instance is appended to the tuple:

```diff
         DeviceBatchedGemmGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
         DeviceBatchedGemmGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
         DeviceBatchedGemmGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
-        DeviceBatchedGemmGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
+        DeviceBatchedGemmGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>,
+        // Padded fallback kernel
+        DeviceBatchedGemmGemm_Xdl_CShuffle<Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, PassThrough, PassThrough, GemmPadded, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8>
         // clang-format on
         >;
```
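The point of appending a padded instance: the `GemmDefault` kernels only accept shapes that are exact multiples of their tile extents, while the `MNKOPadding` one accepts (slower) arbitrary shapes, so whatever selects an instance always finds a match. A hedged standalone model of that first-supported dispatch (selection logic and names here are illustrative, not CK's actual code):

```cpp
#include <cstdio>
#include <functional>
#include <vector>

struct Instance
{
    const char* name;
    std::function<bool(int, int, int, int)> supported;
};

int main()
{
    // Factory for "divisible by tile sizes" support predicates.
    auto tiles = [](int tm, int tn, int tk, int to) {
        return [=](int M, int N, int K, int O) {
            return M % tm == 0 && N % tn == 0 && K % tk == 0 && O % to == 0;
        };
    };

    std::vector<Instance> instances = {
        {"GemmDefault 256x128", tiles(256, 128, 32, 128)},
        {"GemmDefault 128x128", tiles(128, 128, 32, 128)},
        {"GemmPadded fallback", [](int, int, int, int) { return true; }},
    };

    const int M = 136, N = 128, K = 32, O = 128; // M is not a tile multiple
    for(const auto& inst : instances)
        if(inst.supported(M, N, K, O))
        {
            std::printf("dispatch to %s\n", inst.name); // falls through to the padded kernel
            break;
        }
    return 0;
}
```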
profiler/include/profile_batched_gemm_gemm_impl.hpp (+6 −0)

@@ -195,6 +195,12 @@ bool profile_batched_gemm_gemm_impl(bool do_verification,

The profiler now fails early instead of silently "succeeding" when no instances exist:

```diff
     std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
 
+    // early fail when no instances are found
+    if(op_ptrs.size() == 0)
+    {
+        return false;
+    }
+
     if(do_verification)
     {
         auto ref_gemm0 = ReferenceGemm0Instance{};
```
test/batched_gemm_gemm/test_batched_gemm_gemm_fp16.cpp (+109 −0)

@@ -19,6 +19,74 @@ TYPED_TEST_SUITE(TestBatchedGemmGemmFP16, KernelTypes);

New cases exercise each padded dimension. Each length vector is {M, N, K, O, BatchCount} (inferred from which entry each Pad*/Odd* case perturbs):

```cpp
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16) { this->Run(); }

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadM)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {136, 128, 32, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadN)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 136, 32, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadK)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 40, 128, 1},
        {128, 128, 136, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_PadO)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 32, 136, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddM)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {129, 128, 32, 128, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddN)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 129, 32, 128, 1},
    };
    this->Run();
}

// Currently expected that no kernels can support this case
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddK)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 33, 128, 1},
        {128, 128, 129, 128, 1},
    };
    this->Run();
}

// If kernel B1Layout is RowMajor, expect not to support odd O size
TYPED_TEST(TestBatchedGemmGemmFP16, Test_FP16_OddO)
{
    this->lengths_ = std::vector<std::vector<int>>{
        {128, 128, 32, 129, 1},
    };
    this->Run();
}

TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)
{
    this->lengths_ = std::vector<std::vector<int>>{
        // ... benchmark shapes collapsed in the diff view ...
    };
    this->verify_ = false;
    this->Run();
}
```
@@ -37,3 +105,44 @@ TYPED_TEST(TestBatchedGemmGemmFP16, DISABLED_Bench_FP16)

The second hunk appends interface-level tests that pair each `GemmSpecialization` with the dimensions it pads:

```cpp
using ck::tensor_operation::device::GemmSpecialization;

TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMatch)
{
    int P = 120; // requires padding
    int Q = 128; // does not require padding

    // IsSupported(M, N, K, O)
    // clang-format off
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(Q, Q, Q, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MPadding>{}.IsSupported(P, Q, Q, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NPadding>{}.IsSupported(Q, P, Q, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KPadding>{}.IsSupported(Q, Q, P, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNPadding>{}.IsSupported(P, P, Q, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKPadding>{}.IsSupported(P, Q, P, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKPadding>{}.IsSupported(Q, P, P, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(P, P, P, Q));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::OPadding>{}.IsSupported(Q, Q, Q, P));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MOPadding>{}.IsSupported(P, Q, Q, P));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NOPadding>{}.IsSupported(Q, P, Q, P));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::KOPadding>{}.IsSupported(Q, Q, P, P));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNOPadding>{}.IsSupported(P, P, Q, P));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MKOPadding>{}.IsSupported(P, Q, P, P));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::NKOPadding>{}.IsSupported(Q, P, P, P));
    EXPECT_TRUE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(P, P, P, P));
    // clang-format on
}
```
```cpp
TEST(TestBatchedGemmGemmInterface, GemmSpecializationSizeMismatch)
{
    // IsSupported(M, N, K, O)
    // clang-format off
    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::Default>{}.IsSupported(128, 128, 120, 128));
    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKPadding>{}.IsSupported(128, 128, 128, 120));
    // Kernel can't support odd K because K must be integer multiples of K1 values of either A or B
    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 129, 128));
    // Kernel can't support odd O size because it must satisfy SizeO % B1SrcScalarPerVector == 0
    EXPECT_FALSE(DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128<GemmSpecialization::MNKOPadding>{}.IsSupported(128, 128, 128, 129));
    // clang-format on
}
```
test/batched_gemm_gemm/test_batched_gemm_gemm_util.hpp (+121 −0)

@@ -4,8 +4,12 @@

```diff
 #include <iostream>
 #include <vector>
 
+#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_gemm_xdl_cshuffle.hpp"
 #include "profiler/include/profile_batched_gemm_gemm_impl.hpp"
 
+using ck::tensor_operation::device::GemmSpecialization;
+
 template <ck::index_t N>
 using I = ck::Number<N>;
```
@@ -66,3 +70,120 @@ struct TestBatchedGemmGemm : public ::testing::Test

After the closing braces of the existing `TestBatchedGemmGemm` fixture, the hunk appends a wrapper that builds one concrete device instance per `GemmSpecialization` and dry-runs `IsSupportedArgument` with null data pointers:

```cpp
        }
    }
};

template <GemmSpecialization GemmSpec>
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    using ALayout  = Row;
    using B0Layout = Col;
    using B1Layout = Row;
    using CLayout  = Row;

    using ADataType        = F16;
    using B0DataType       = F16;
    using B1DataType       = F16;
    using AccDataType      = float;
    using CShuffleDataType = float;
    using CDataType        = F16;

    using AElementOp    = PassThrough;
    using B0ElementOp   = PassThrough;
    using Acc0ElementOp = PassThrough;
    using B1ElementOp   = PassThrough;
    using CElementOp    = PassThrough;

    template <ck::index_t... Is>
    using S = ck::Sequence<Is...>;

    // static constexpr auto GemmSpec = std::tuple_element_t<0, Tuple>::value;
    using DeviceGemmGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmGemm_Xdl_CShuffle<
        ALayout, B0Layout, B1Layout, CLayout,
        ADataType, B0DataType, B1DataType, CDataType, AccDataType, CShuffleDataType,
        AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp,
        GemmSpec,
        1,
        256,
        128,         // MPerBlock
        128,         // NPerBlock
        32,          // KPerBlock
        128,         // Gemm1NPerBlock
        32,          // Gemm1KPerBlock
        8,           // AK1
        8,           // BK1
        2,           // B1K1
        32,          // MPerXDL
        32,          // NPerXDL
        1,           // MXdlPerWave
        4,           // NXdlPerWave
        4,           // Gemm1NXdlPerWave
        S<4, 64, 1>, // ABlockTransfer
        S<1, 0, 2>,
        S<1, 0, 2>,
        2,
        8,
        8,
        true,
        S<4, 64, 1>, // BBlockTransfer
        S<1, 0, 2>,
        S<1, 0, 2>,
        2,
        8,
        8,
        true,
        S<8, 32, 1>, // B1BlockTransfer
        S<0, 2, 1>,
        S<0, 2, 1>,
        1,
        4,
        2,
        false,
        1,              // CShuffleMXdlPerWavePerShuffle
        2,              // CShuffleNXdlPerWavePerShuffle
        S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
        8>;             // CShuffleBlockTransferScalarPerVector_NPerBlock

    bool IsSupported(int M, int N, int K, int O)
    {
        auto gemm     = DeviceGemmGemmInstance{};
        auto invoker  = gemm.MakeInvoker();
        auto argument = gemm.MakeArgument(static_cast<ADataType*>(nullptr),
                                          static_cast<B0DataType*>(nullptr),
                                          static_cast<B1DataType*>(nullptr),
                                          static_cast<CDataType*>(nullptr),
                                          M,
                                          N,
                                          K,
                                          O,
                                          0,  // BatchCount
                                          0,  // StrideA
                                          0,  // StrideB0
                                          0,  // StrideB1
                                          0,  // StrideC
                                          0,  // BatchStrideA
                                          0,  // BatchStrideB0
                                          0,  // BatchStrideB1
                                          0,  // BatchStrideC
                                          PassThrough{},  // a_element_op
                                          PassThrough{},  // b0_element_op
                                          PassThrough{},  // acc0_element_op
                                          PassThrough{},  // b1_element_op
                                          PassThrough{}); // c_element_op

        return gemm.IsSupportedArgument(argument);
    }
};
```