gaoqiong / composable_kernel_ROCM · Commits · 26d5174e
Commit 26d5174e, authored Nov 26, 2024 by aska-0096

update instance and lds layout strategy

Parent: ea90b01f
Showing 8 changed files with 412 additions and 15 deletions (+412 −15)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp (+8 −4)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp (+211 −4)
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp (+34 −4)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp (+31 −3)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part1.cpp (+32 −0)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part2.cpp (+32 −0)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part1.cpp (+32 −0)
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part2.cpp (+32 −0)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
@@ -615,11 +615,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
     __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
-        if constexpr(ABlockLdsExtraM)
+        if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
         {
+            // bank conflict when writing the data into LDS, but don't worry, we have the whole
+            // loop to hide it in v4; it may give some benefit from fewer valu in compute address
             return make_naive_tensor_descriptor(
                 make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
-                make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
+                make_tuple(Number<MPerBlock>{} * AK1Number, AK1Number, I1));
         }
         // xor tensor transformation requires more unnecessary vgpr usage and would cause
         // register spill in some cases.
@@ -752,11 +754,13 @@ struct GridwiseGemm_xdl_cshuffle_v3
     __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
     {
         // B matrix in LDS memory, dst of blockwise copy
-        if constexpr(BBlockLdsExtraN)
+        if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
        {
+            // bank conflict when writing the data into LDS, but don't worry, we have the whole
+            // loop to hide it in v4; it may give some benefit from fewer valu in compute address
             return make_naive_tensor_descriptor(
                 make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
-                make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
+                make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
         }
         else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
         {
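
The layout strategy behind these two hunks, in plain terms: the old descriptors pad each M- (or N-) row with ABlockLdsExtraM/BBlockLdsExtraN elements so LDS writes stagger across banks; the new descriptors use a packed K0-major layout whose addresses are cheaper to compute. The following standalone sketch (mine, with made-up tile sizes, not taken from the commit) shows the offset arithmetic each A-block descriptor implies:

    // Offset math implied by the two A-block LDS descriptors (illustrative only;
    // the KPerBlock/AK1/MPerBlock/ExtraM values are assumptions, not from the commit).
    #include <cstdio>

    constexpr int KPerBlock = 64, AK1 = 8, MPerBlock = 128, ExtraM = 1;

    // old: dims (AK0, M, AK1), strides (AK1, KPerBlock + ExtraM, 1)
    constexpr int offset_old(int k0, int m, int k1)
    {
        return k0 * AK1 + m * (KPerBlock + ExtraM) + k1;
    }

    // new: dims (AK0, M, AK1), strides (MPerBlock * AK1, AK1, 1)
    constexpr int offset_new(int k0, int m, int k1)
    {
        return k0 * (MPerBlock * AK1) + m * AK1 + k1;
    }

    int main()
    {
        // Adjacent m at fixed (k0, k1): the old layout spaces rows by
        // KPerBlock + ExtraM elements (bank-staggered), the new one by AK1.
        std::printf("old: m=0 -> %d, m=1 -> %d\n", offset_old(0, 0, 0), offset_old(0, 1, 0));
        std::printf("new: m=0 -> %d, m=1 -> %d\n", offset_new(0, 0, 0), offset_new(0, 1, 0));
    }

As the new comment says, the packed form can create bank conflicts on the LDS write, but under the v4 pipeline a full loop iteration of compute is available to hide them, in exchange for fewer VALU address-computation instructions.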
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
@@ -676,11 +676,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
     __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy
-        if constexpr(ABlockLdsExtraM)
+        if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
         {
+            // bank conflict when writing the data into LDS, but don't worry, we have the whole
+            // loop to hide it in v4; it may give some benefit from fewer valu in compute address
             return make_naive_tensor_descriptor(
                 make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
-                make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
+                make_tuple(Number<MPerBlock>{} * AK1Number, AK1Number, I1));
         }
         // xor tensor transformation requires more unnecessary vgpr usage and would cause
         // register spill in some cases.
@@ -813,11 +815,13 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
     __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
     {
         // B matrix in LDS memory, dst of blockwise copy
-        if constexpr(BBlockLdsExtraN)
+        if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
         {
+            // bank conflict when writing the data into LDS, but don't worry, we have the whole
+            // loop to hide it in v4; it may give some benefit from fewer valu in compute address
             return make_naive_tensor_descriptor(
                 make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
-                make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
+                make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
         }
         else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
         {
@@ -1216,6 +1220,38 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         return c_grid_desc_mblock_mperblock_nblock_nperblock;
     }

+    __device__ static constexpr auto EpilogueScheduler()
+    {
+        constexpr auto epilogue_tile = MPerBlock * NPerBlock * CShuffleMXdlPerWavePerShuffle *
+                                       CShuffleNXdlPerWavePerShuffle /
+                                       (MXdlPerWave * NXdlPerWave);
+        constexpr auto num_mfma_inst = BlockwiseGemmPipe::HotLoopInstList::C_MFMA_Inst_Num *
+                                       CShuffleMXdlPerWavePerShuffle *
+                                       CShuffleNXdlPerWavePerShuffle /
+                                       (MXdlPerWave * NXdlPerWave);
+
+        constexpr auto num_ds_write_inst = epilogue_tile / BlockSize; // DefaultMFMA, per-element write
+        constexpr auto num_ds_read_inst =
+            epilogue_tile / BlockSize / CShuffleBlockTransferScalarPerVector_NPerBlock;
+        constexpr auto num_buffer_store_inst = num_ds_read_inst;
+
+        // MFMA : ds_write = 1 : 2
+        constexpr auto num_ds_write_issue  = num_ds_write_inst / 2;
+        constexpr auto num_mfma_block_sync = (num_mfma_inst - num_ds_write_issue) / 2;
+        constexpr auto mfma_ds_write_rate  = MXdlPerWave == 16 ? 2 : 4;
+
+        // Hide ds_write issue latency
+        static_for<0, num_ds_write_issue, 1>{}([&](auto i) {
+            ignore = i;
+            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);                  // MFMA
+            __builtin_amdgcn_sched_group_barrier(0x200, mfma_ds_write_rate, 0); // DS write
+        });
+        // Hide block_sync + ds_read latency
+        __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_block_sync, 0); // MFMA
+        __builtin_amdgcn_sched_group_barrier(0x100, num_ds_read_inst, 0);    // DS read
+        // Hide block_sync latency
+        __builtin_amdgcn_sched_group_barrier(0x008, num_mfma_block_sync, 0);   // MFMA
+        __builtin_amdgcn_sched_group_barrier(0x040, num_buffer_store_inst, 0); // VMEM write
+    }
+
     // return block_id to C matrix tile idx (m0, n0) mapping
     // if arch = gfx942
     using Block2CTileMapDefault = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
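
The mask constants follow the LLVM AMDGPU `__builtin_amdgcn_sched_group_barrier` encoding (0x008 = MFMA, 0x040 = VMEM write, 0x100 = DS read, 0x200 = DS write), so the function pins down an interleave of one MFMA per `mfma_ds_write_rate` LDS writes, then spreads the remaining MFMAs across the LDS reads and buffer stores. A standalone arithmetic check of the budgets, with tile parameters of my own choosing (illustrative, not from the commit):

    // Instruction budgets computed the same way as EpilogueScheduler, for one
    // plausible (assumed) configuration.
    constexpr int MPerBlock = 256, NPerBlock = 128;
    constexpr int MXdlPerWave = 4, NXdlPerWave = 2;
    constexpr int CShuffleM = 1, CShuffleN = 1; // *XdlPerWavePerShuffle
    constexpr int BlockSize = 256, ScalarPerVector = 8;

    constexpr int epilogue_tile = MPerBlock * NPerBlock * CShuffleM * CShuffleN /
                                  (MXdlPerWave * NXdlPerWave);                     // 4096 elements
    constexpr int num_ds_write_inst = epilogue_tile / BlockSize;                   // 16 writes/thread
    constexpr int num_ds_read_inst  = epilogue_tile / BlockSize / ScalarPerVector; // 2 vector reads
    constexpr int num_buffer_store_inst = num_ds_read_inst;                        // 2 stores
    constexpr int num_ds_write_issue    = num_ds_write_inst / 2;      // 8 (MFMA:ds_write = 1:2)
    constexpr int mfma_ds_write_rate    = (MXdlPerWave == 16) ? 2 : 4; // 4 here

    static_assert(epilogue_tile == 4096 && num_ds_write_issue == 8, "budget check");

    int main() { return 0; }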
@@ -1393,6 +1429,15 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
         auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

+        constexpr auto a_thread_desc = blockwise_gemm_pipeline.a_thread_desc_;
+        constexpr auto b_thread_desc = blockwise_gemm_pipeline.b_thread_desc_;
+        constexpr auto c_thread_desc = blockwise_gemm_pipeline.c_thread_desc_;
+
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
+            a_thread_desc.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
+            b_thread_desc.GetElementSpaceSize());
+
         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);
@@ -1410,10 +1455,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                 b_block_buf,
                 b_block_slice_copy_step,
                 c_thread_buf,
+                a_thread_buf,
+                b_thread_buf,
                 num_k_block_main_loop);

         // shuffle C and write out
         {
+            // Last block MFMA
+            auto xdlops_gemm       = blockwise_gemm_pipeline.xdlops_gemm;
+            constexpr auto KRepeat = blockwise_gemm_pipeline.KRepeat;
+
             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                           "wrong!");
@@ -1573,6 +1624,9 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;

             const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;

+            // C: LDS -> VGPR
+            // D: Global -> VGPR
+            // E: = Epilogue(C, D), VGPR -> Global
             auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3<
                 ThisThreadBlock,
                 decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
@@ -1631,10 +1685,77 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
             static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

+            static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
+                static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
+                    static_for<0, KRepeat, 1>{}([&](auto k0) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(m0, I0, k0, ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(n0, I0, k0, ik))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset =
+                            c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                        b_thread_vec.template AsType<mfma_input_type>(),
+                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
+                });
+            });
+
+            __builtin_amdgcn_sched_barrier(0);
+
             static_for<0, num_access, 1>{}([&](auto access_id) {
                 // make sure it's safe to write to LDS
                 block_sync_lds();

+                if constexpr(access_id < num_access - 1)
+                {
+                    constexpr auto shuffle_m0 = sfc_c_vgpr.GetIndexTupleOfNumber(
+                        access_id + Number<1>{})[Number<0>{}];
+                    constexpr auto shuffle_n0 = sfc_c_vgpr.GetIndexTupleOfNumber(
+                        access_id + Number<1>{})[Number<1>{}];
+
+                    static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
+                        static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
+                            static_for<0, KRepeat, 1>{}([&](auto k0) {
+                                vector_type<ComputeTypeA, KPack> a_thread_vec;
+                                vector_type<ComputeTypeB, KPack> b_thread_vec;
+
+                                static_for<0, KPack, 1>{}([&](auto ik) {
+                                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                        a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                            make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
+                                    b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                        b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                            make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
+                                });
+
+                                using mfma_input_type =
+                                    typename vector_type<ComputeTypeA,
+                                                         xdlops_gemm.K1PerXdlops>::type;
+
+                                constexpr index_t c_offset = c_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
+
+                                xdlops_gemm.Run(
+                                    a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                            });
+                        });
+                    });
+                }
+
                 // each thread write its data from VGPR to LDS
                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -1668,6 +1789,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                         tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                         I0,
                         cde_lds_and_global_step);
+
+                    EpilogueScheduler();
                 }
             });
         }
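
Taken together, the hunks above restructure the epilogue so compute and memory overlap: the blockwise pipeline leaves the last tile's MFMAs pending in `a_thread_buf`/`b_thread_buf`, and inside the write-out loop the MFMAs for access `access_id + 1` are issued right after the LDS barrier, before tile `access_id` is shuffled and stored. A host-side toy of the control flow (my stub names, not CK API):

    #include <cstdio>

    // Hypothetical stand-ins for the kernel phases.
    void block_sync() {}
    void mfma_for_tile(int t) { std::printf("mfma for tile %d\n", t); }
    void shuffle_and_store(int t) { std::printf("store tile %d\n", t); }

    int main()
    {
        const int num_access = 4;
        mfma_for_tile(0); // peeled: finish the first shuffle tile before the loop
        for(int i = 0; i < num_access; ++i)
        {
            block_sync(); // make sure it's safe to write to LDS
            if(i < num_access - 1)
                mfma_for_tile(i + 1); // next tile's MFMAs hide behind this store
            shuffle_and_store(i);     // LDS round trip + store, paced by EpilogueScheduler()
        }
    }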
@@ -1860,6 +1983,15 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
         auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

+        constexpr auto a_thread_desc = blockwise_gemm_pipeline.a_thread_desc_;
+        constexpr auto b_thread_desc = blockwise_gemm_pipeline.b_thread_desc_;
+        constexpr auto c_thread_desc = blockwise_gemm_pipeline.c_thread_desc_;
+
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
+            a_thread_desc.GetElementSpaceSize());
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
+            b_thread_desc.GetElementSpaceSize());
+
         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
             KPerBlock);
@@ -1877,10 +2009,16 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                 b_block_bufs,
                 b_block_slice_copy_step,
                 c_thread_buf,
+                a_thread_buf,
+                b_thread_buf,
                 num_k_block_main_loop);

         // shuffle C and write out
         {
+            // Last block MFMA
+            auto xdlops_gemm       = blockwise_gemm_pipeline.xdlops_gemm;
+            constexpr auto KRepeat = blockwise_gemm_pipeline.KRepeat;
+
             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                           "wrong!");
@@ -2098,10 +2236,77 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
             static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

+            static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
+                static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
+                    static_for<0, KRepeat, 1>{}([&](auto k0) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;
+
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(m0, I0, k0, ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(n0, I0, k0, ik))>{}];
+                        });
+
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+
+                        constexpr index_t c_offset =
+                            c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
+
+                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
+                                        b_thread_vec.template AsType<mfma_input_type>(),
+                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
+                });
+            });
+
+            __builtin_amdgcn_sched_barrier(0);
+
             static_for<0, num_access, 1>{}([&](auto access_id) {
                 // make sure it's safe to write to LDS
                 block_sync_lds();

+                if constexpr(access_id < num_access - 1)
+                {
+                    constexpr auto shuffle_m0 = sfc_c_vgpr.GetIndexTupleOfNumber(
+                        access_id + Number<1>{})[Number<0>{}];
+                    constexpr auto shuffle_n0 = sfc_c_vgpr.GetIndexTupleOfNumber(
+                        access_id + Number<1>{})[Number<1>{}];
+
+                    static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
+                        static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
+                            static_for<0, KRepeat, 1>{}([&](auto k0) {
+                                vector_type<ComputeTypeA, KPack> a_thread_vec;
+                                vector_type<ComputeTypeB, KPack> b_thread_vec;
+
+                                static_for<0, KPack, 1>{}([&](auto ik) {
+                                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                        a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                            make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
+                                    b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                        b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                            make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
+                                });
+
+                                using mfma_input_type =
+                                    typename vector_type<ComputeTypeA,
+                                                         xdlops_gemm.K1PerXdlops>::type;
+
+                                constexpr index_t c_offset = c_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
+
+                                xdlops_gemm.Run(
+                                    a_thread_vec.template AsType<mfma_input_type>(),
+                                    b_thread_vec.template AsType<mfma_input_type>(),
+                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                            });
+                        });
+                    });
+                }
+
                 // each thread write its data from VGPR to LDS
                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
@@ -2135,6 +2340,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                         tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                         I0,
                         cde_lds_and_global_step);
+
+                    EpilogueScheduler();
                 }
             });
         }
library/include/ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp
@@ -17,7 +17,7 @@ namespace tensor_operation {
 namespace device {
 namespace instance {
 #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8))
-void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
+void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
                                                           Col,
                                                           Tuple<Row, Col>,
@@ -30,7 +30,33 @@ void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_inst
                                                           PassThrough,
                                                           PassThrough,
                                                           MultiplyMultiply>>>& instances);

-void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
+void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part1(
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
+                                                          Col,
+                                                          Tuple<Row, Col>,
+                                                          Row,
+                                                          F8,
+                                                          F8,
+                                                          Tuple<F32, F32>,
+                                                          BF16,
+                                                          PassThrough,
+                                                          PassThrough,
+                                                          MultiplyMultiply>>>& instances);
+
+void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part2(
+    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
+                                                          Col,
+                                                          Tuple<Row, Col>,
+                                                          Row,
+                                                          F8,
+                                                          F8,
+                                                          Tuple<F32, F32>,
+                                                          BF16,
+                                                          PassThrough,
+                                                          PassThrough,
+                                                          MultiplyMultiply>>>& instances);
+
+void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part2(
     std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
                                                           Col,
                                                           Tuple<Row, Col>,
@@ -221,9 +247,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGemmMu
         if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                      is_same_v<CLayout, Row>)
         {
-            add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances(
+            add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1(
+                op_ptrs);
+            add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part1(
+                op_ptrs);
+            add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part2(
                 op_ptrs);
-            add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances(
+            add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part2(
                 op_ptrs);
             add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_v1_default_instances(
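
Client code does not change: it still asks the factory for all registered instances, which now arrive from four translation units instead of two. A minimal usage sketch (the aliases follow the usual CK client-example spelling; treat the exact names as assumptions that may vary by version):

    #include "ck/library/tensor_operation_instance/gpu/gemm_multiply_multiply.hpp"

    using Row  = ck::tensor_layout::gemm::RowMajor;
    using Col  = ck::tensor_layout::gemm::ColumnMajor;
    using F8   = ck::f8_t;
    using F32  = float;
    using BF16 = ck::bhalf_t;
    using PassThrough      = ck::tensor_operation::element_wise::PassThrough;
    using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;

    using DeviceOp = ck::tensor_operation::device::DeviceGemmMultipleDSplitK<
        Row, Col, ck::Tuple<Row, Col>, Row,
        F8, F8, ck::Tuple<F32, F32>, BF16,
        PassThrough, PassThrough, MultiplyMultiply>;

    int main()
    {
        // Collects every registered instance, including the new part1/part2 splits.
        const auto op_ptrs = ck::tensor_operation::device::instance::
            DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
        return op_ptrs.empty() ? 1 : 0; // expect a non-empty list on a supported arch
    }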
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp

This diff is collapsed.
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part1.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part1(
    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
                                                          Col,
                                                          Tuple<Row, Col>,
                                                          Row,
                                                          F8,
                                                          F8,
                                                          Tuple<F32, F32>,
                                                          BF16,
                                                          PassThrough,
                                                          PassThrough,
                                                          MultiplyMultiply>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1<GemmDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
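
The four new .cpp files are mechanical: each registers one half of the instance list, presumably so that no single translation unit carries the whole template-instantiation load. What an add_device_operation_instances-style helper does, reduced to a toy (my names, not the CK implementation):

    #include <memory>
    #include <tuple>
    #include <vector>

    struct BaseOp { virtual ~BaseOp() = default; };
    struct OpA : BaseOp {};
    struct OpB : BaseOp {};

    // Copy each concrete op from a compile-time tuple into the type-erased list.
    template <typename... Ops>
    void add_instances(std::vector<std::unique_ptr<BaseOp>>& out, std::tuple<Ops...>)
    {
        (out.push_back(std::make_unique<Ops>()), ...);
    }

    int main()
    {
        std::vector<std::unique_ptr<BaseOp>> instances;
        add_instances(instances, std::tuple<OpA, OpB>{}); // instances.size() == 2
    }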
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instance_part2.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_default_instances_part2(
    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
                                                          Col,
                                                          Tuple<Row, Col>,
                                                          Row,
                                                          F8,
                                                          F8,
                                                          Tuple<F32, F32>,
                                                          BF16,
                                                          PassThrough,
                                                          PassThrough,
                                                          MultiplyMultiply>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part2<GemmDefault>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part1.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part1(
    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
                                                          Col,
                                                          Tuple<Row, Col>,
                                                          Row,
                                                          F8,
                                                          F8,
                                                          Tuple<F32, F32>,
                                                          BF16,
                                                          PassThrough,
                                                          PassThrough,
                                                          MultiplyMultiply>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part1<GemmKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instance_part2.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_kpadding_instances_part2(
    std::vector<std::unique_ptr<DeviceGemmMultipleDSplitK<Row,
                                                          Col,
                                                          Tuple<Row, Col>,
                                                          Row,
                                                          F8,
                                                          F8,
                                                          Tuple<F32, F32>,
                                                          BF16,
                                                          PassThrough,
                                                          PassThrough,
                                                          MultiplyMultiply>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances_part2<GemmKPadding>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck