gaoqiong / composable_kernel_ROCM · Commit e95cb82b

Fix gemm gemm on gfx950

Authored Jan 31, 2025 by jefyang1
Parent: c38163cd

Showing 7 changed files with 61 additions and 33 deletions.
example/31_batched_gemm_gemm/CMakeLists.txt                                                                 +1  -1
example/41_grouped_conv_conv_fwd/CMakeLists.txt                                                             +1  -1
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp                         +0  -6
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp   +12 -5
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp      +2  -6
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp                 +2  -6
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp                                                        +43 -8
example/31_batched_gemm_gemm/CMakeLists.txt

@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_batched_gemm_gemm_xdl_int4 batched_gemm_gemm_xdl_int4.cpp)
 endif(USE_BITINT_EXTENSION_INT4)
-if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
+if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
     add_example_executable(example_batched_gemm_gemm_xdl_int8 batched_gemm_gemm_xdl_int8.cpp)
 endif()
example/41_grouped_conv_conv_fwd/CMakeLists.txt

@@ -5,6 +5,6 @@ if(USE_BITINT_EXTENSION_INT4)
     add_example_executable(example_grouped_conv_conv_fwd_xdl_int4 grouped_conv_conv_fwd_xdl_int4.cpp)
 endif(USE_BITINT_EXTENSION_INT4)
-if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx1")
+if(NOT GPU_TARGETS MATCHES "gfx94" AND NOT GPU_TARGETS MATCHES "gfx95" AND NOT GPU_TARGETS MATCHES "gfx1")
     add_example_executable(example_grouped_conv_conv_fwd_xdl_int8 grouped_conv_conv_fwd_xdl_int8.cpp)
 endif()
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_xdl_cshuffle_v1.hpp

@@ -608,14 +608,8 @@ struct GridwiseBatchedGemmGemm_Xdl_CShuffle
         // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
         // therefore we may just as well assign Gemm1KPack = group_size
-#if defined(__gfx950__)
-        // TODO: fix logic for gfx950 as it's temporary hack for passing compiling
-        constexpr index_t Gemm1KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
-#else
         constexpr index_t Gemm1KPack =
             MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
-#endif

         auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
             BlockSize,
             ...
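The deleted block is the workaround this commit retires: on gfx950 it forced Gemm1KPack up to max(lcm(AK1, BK1), k_per_blk), which can exceed selected_mfma.group_size and so conflict with the constraint spelled out in the comment above (the removed TODO itself calls it a temporary hack for passing compilation). A minimal standalone sketch of the two computations, using std::lcm/std::max and purely illustrative values (AK1 = BK1 = 8, group_size = 4, k_per_blk = 8; the real values come from MfmaSelector, not from this sketch):

    // Standalone sketch: compares the removed gfx950 workaround with the retained rule.
    // All numeric values are assumptions chosen for illustration, not taken from MfmaSelector.
    #include <algorithm>
    #include <cstdio>
    #include <numeric>

    int main()
    {
        constexpr int AK1        = 8; // assumed A-side K vector length
        constexpr int BK1        = 8; // assumed B-side K vector length
        constexpr int group_size = 4; // contiguous K elements a thread holds of the Gemm0 output
        constexpr int k_per_blk  = 8; // K elements one MFMA consumes per lane (double-rate case)

        // Removed workaround: can exceed group_size, breaking the A1-in-VGPRs layout assumption.
        constexpr int gemm1_kpack_hack = std::max(std::lcm(AK1, BK1), k_per_blk);

        // Retained rule: Gemm1KPack stays pinned to group_size, as the surrounding comment argues.
        constexpr int gemm1_kpack = group_size;

        std::printf("hack: %d  retained: %d\n", gemm1_kpack_hack, gemm1_kpack); // hack: 8  retained: 4
        return 0;
    }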
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp

@@ -856,11 +856,18 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
             static_cast<A0B0B1DataType*>(p_shared) + SharedMemTrait::b1_block_space_offset,
             b1_block_desc_bk0_n_bk1.GetElementSpaceSize());

-        constexpr index_t Gemm1KPack = math::max(
-            math::lcm(
-                MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size,
-                B1K1),
-            MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.k_per_blk);
+        // selected_mfma.group_size or B1K1 <= Gemm1KPack <= selected_mfma.group_size
+        // selected_mfma.k_per_blk <= Gemm1KPack
+        //
+        // Following similar rationale behind Gemm0KPack, let Gemm1KPack be the lowest common
+        // multiples of A1K1 (predetermined by selected_mfma.group_size) and B1K1. But in this case
+        // Gemm1KPack can't be higher than A1K1 itself because A1 matrix is distributed in VGPRs
+        // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
+        // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
+        // therefore we may just as well assign Gemm1KPack = group_size
+        constexpr index_t Gemm1KPack =
+            MfmaSelector<A0B0B1DataType, Gemm0MPerXdl, Gemm0NPerXdl>::selected_mfma.group_size;

         auto blockwise_gemm1 = BlockwiseGemmXdlops_v2<
             BlockSize,
             ...
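The restored comment block carries the core argument: the A1 operand already sits in VGPRs as group_size-sized contiguous chunks of the Gemm0 output, so packing more than group_size K elements pairs the wrong indices (the c[0:7] = a1[[0:3, 8:11]] * b1[0:7] example). A small host-side sketch of that mismatch, with an assumed group_size of 4 and a hypothetical per-lane register layout:

    // Standalone sketch of the summation-index mismatch described in the comment above.
    // The register layout and the data are assumptions chosen only to illustrate the effect.
    #include <cstdio>

    int main()
    {
        // K indices this lane actually holds for A1: two groups of 4 contiguous elements,
        // the second group starting at k = 8 because other lanes own k = 4..7.
        const int a1_k_index[8] = {0, 1, 2, 3, 8, 9, 10, 11};

        double a1[16], b1[16];
        for(int k = 0; k < 16; ++k)
        {
            a1[k] = 1.0 + k; // arbitrary test data
            b1[k] = 0.5 * k;
        }

        // Correct accumulation: pair each register with the K index it really represents.
        double correct = 0.0;
        for(int r = 0; r < 8; ++r)
            correct += a1[a1_k_index[r]] * b1[a1_k_index[r]];

        // A Gemm1KPack of 8 would treat the 8 registers as contiguous k = 0..7 and multiply
        // them with b1[0..7], i.e. c[0:7] = a1[[0:3, 8:11]] * b1[0:7]: wrong summation indices.
        double mismatched = 0.0;
        for(int r = 0; r < 8; ++r)
            mismatched += a1[a1_k_index[r]] * b1[r];

        std::printf("correct = %.1f, mismatched = %.1f\n", correct, mismatched);
        return 0;
    }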
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_softmax_gemm_xdl_cshuffle_v1.hpp

@@ -773,14 +773,10 @@ struct GridwiseBatchedGemmMultipleDSoftmaxGemm_Xdl_CShuffle
         // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
         // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
         // therefore we may just as well assign Gemm1KPack = group_size
-#if defined(__gfx950__)
-        // TODO: fix logic for gfx950 as it's temporary hack for passing compiling
-        constexpr index_t Gemm1KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
-#else
         constexpr index_t Gemm1KPack =
             MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
-#endif

         auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
             BlockSize,
             FloatAB,
             ...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp

@@ -628,14 +628,10 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
         // with 'group_size' amount of contiguous elements. Having Gemm1KPack greater than A1K1 will
         // cause mismatch in summation index for example c[0:7] = a1[[0:3, 8:11]] * b1[0:7].
         // therefore we may just as well assign Gemm1KPack = group_size
-#if defined(__gfx950__)
-        // TODO: fix logic for gfx950 as it's temporary hack for passing compiling
-        constexpr index_t Gemm1KPack = math::max(
-            math::lcm(AK1, BK1), MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
-#else
         constexpr index_t Gemm1KPack =
             MfmaSelector<FloatAB, MPerXdl, NPerXdl>::selected_mfma.group_size;
-#endif

         auto gemm1_blockwise_gemm = BlockwiseGemmXdlops_v2<
             BlockSize,
             FloatAB,
             ...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp

@@ -880,13 +880,15 @@ struct mfma_type<MfmaInstr::mfma_scale_f32_16x16x128f8f6f4>
 template <typename base_type,
           index_t MPerXdlops,
           index_t NPerXdlops,
-          typename additional_type = base_type>
+          typename additional_type = base_type,
+          bool is_single_rate_mfma = false>
 struct MfmaSelector
 {
     template <typename base_type_,
               index_t MPerXdlops_,
               index_t NPerXdlops_,
-              typename additional_type_ = base_type_>
+              typename additional_type_ = base_type_,
+              bool is_single_rate_mfma_ = false>
     static constexpr auto GetMfma();

     template <>
     ...

@@ -950,7 +952,7 @@ struct MfmaSelector
     }

     template <>
-    constexpr auto GetMfma<half_t, 32, 32>()
+    constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
     {
 #if defined(__gfx950__)
         return MfmaInstr::mfma_f32_32x32x16f16;
     ...

@@ -958,9 +960,14 @@ struct MfmaSelector
         return MfmaInstr::mfma_f32_32x32x8f16;
 #endif
     }

+    template <>
+    constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
+    {
+        return MfmaInstr::mfma_f32_32x32x8f16;
+    }
+
     template <>
-    constexpr auto GetMfma<half_t, 16, 16>()
+    constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
     {
 #if defined(__gfx950__)
         return MfmaInstr::mfma_f32_16x16x32f16;
     ...

@@ -969,6 +976,12 @@ struct MfmaSelector
 #endif
     }

+    template <>
+    constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
+    {
+        return MfmaInstr::mfma_f32_16x16x16f16;
+    }
+
     template <>
     constexpr auto GetMfma<half_t, 16, 64>()
     {
     ...

@@ -988,7 +1001,7 @@ struct MfmaSelector
     }

     template <>
-    constexpr auto GetMfma<bhalf_t, 32, 32>()
+    constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
     {
 #if defined(__gfx950__)
         return MfmaInstr::mfma_f32_32x32x16bf16;
     ...

@@ -1000,7 +1013,17 @@ struct MfmaSelector
     }

     template <>
-    constexpr auto GetMfma<bhalf_t, 16, 16>()
+    constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
+    {
+#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
+        return MfmaInstr::mfma_f32_32x32x8bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_32x32x4bf16;
+#endif
+    }
+
+    template <>
+    constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
     {
 #if defined(__gfx950__)
         return MfmaInstr::mfma_f32_16x16x32bf16;
     ...

@@ -1011,6 +1034,16 @@ struct MfmaSelector
 #endif
     }

+    template <>
+    constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
+    {
+#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
+        return MfmaInstr::mfma_f32_16x16x16bf16_1k;
+#else
+        return MfmaInstr::mfma_f32_16x16x8bf16;
+#endif
+    }
+
 #if defined(__gfx950__)
     template <>
     constexpr auto GetMfma<int8_t, 32, 32>()
     ...

@@ -1095,7 +1128,7 @@ struct MfmaSelector
     }

     static constexpr auto selected_mfma =
-        mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type>()>{};
+        mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};

     __host__ __device__ constexpr MfmaSelector()
     {
     ...

@@ -1397,7 +1430,9 @@ struct XdlopsGemm
         return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
     }

-    static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type>{};
+    // Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942-
+    static constexpr auto mfma = MfmaSelector<base_type, MPerXdlops, NPerXdlops, additional_type,
+        ((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) && KPack <= 4) ? true : false>{};

     static constexpr auto mfma_instr = mfma.selected_mfma;
     ...
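Taken together, the new is_single_rate_mfma flag lets XdlopsGemm drop back to the pre-gfx950 single-rate f16/bf16 instructions whenever the chosen KPack (4 or less) is too small for the wider-K gfx950 double-rate MFMAs, instead of patching each gridwise kernel with #if defined(__gfx950__) blocks. A compact host-side sketch of that selection rule, modelling only the fp16 specializations touched by this commit (the enum and the select_fp16_mfma helper are simplified stand-ins, not CK's real MfmaSelector; the bf16 paths follow the same pattern with the bf16 instructions):

    // Simplified stand-in for the MfmaSelector decision added in this commit.
    // Names and structure are illustrative; only the fp16 32x32 and 16x16 cases are modelled.
    #include <cstdio>

    enum class Instr
    {
        mfma_f32_32x32x16f16, // gfx950 double-rate
        mfma_f32_32x32x8f16,  // single-rate fallback
        mfma_f32_16x16x32f16, // gfx950 double-rate
        mfma_f32_16x16x16f16  // single-rate fallback
    };

    // Mirrors the new template argument:
    // is_single_rate_mfma = ((half_t || bhalf_t) && KPack <= 4) ? true : false
    constexpr bool single_rate(bool is_f16_or_bf16, int k_pack)
    {
        return is_f16_or_bf16 && k_pack <= 4;
    }

    constexpr Instr select_fp16_mfma(int m_per_xdl, int k_pack, bool on_gfx950)
    {
        const bool fallback = single_rate(/*is_f16_or_bf16=*/true, k_pack);
        if(m_per_xdl == 32)
            return (on_gfx950 && !fallback) ? Instr::mfma_f32_32x32x16f16
                                            : Instr::mfma_f32_32x32x8f16;
        return (on_gfx950 && !fallback) ? Instr::mfma_f32_16x16x32f16
                                        : Instr::mfma_f32_16x16x16f16;
    }

    int main()
    {
        // KPack = 8 on gfx950 keeps the double-rate instruction; KPack = 4 forces single-rate.
        std::printf("32x32, KPack=8: %d\n", static_cast<int>(select_fp16_mfma(32, 8, true)));
        std::printf("32x32, KPack=4: %d\n", static_cast<int>(select_fp16_mfma(32, 4, true)));
        return 0;
    }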