gaoqiong / composable_kernel_ROCM · Commits

Commit 174b46b0
authored Dec 27, 2024 by coderfeli

add cpu shuffle

parent e6f5a78b
Showing 5 changed files with 99 additions and 129 deletions.
    example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp                    +28 -25
    include/ck/library/utility/host_tensor_generator.hpp                                    +1  -1
    include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp             +16 -21
    include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp  +46 -46
    include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp          +8  -36
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp
@@ -63,36 +63,38 @@ struct MultiplyMultiply
     }
 };

-void reshapeBuffer(char* buffer, int N, int K, char* output)
-{
-    const int KRepeat = 2;
-    const int NRepeat = 3;
-    const int KLane   = 4;
-    const int NLane   = 5;
-    const int KPack   = 6;
+void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst)
+{
+    const int NRepeat = 1;
+    const int KRepeat = 4;
+    const int KLane   = 2;
+    const int NLane   = 128;
+    const int KPack   = 16;

     int N0 = N / (NRepeat * NLane);
     int K0 = K / (KRepeat * KLane * KPack);
+    int tempn, tempk;

     for(int n = 0; n < N; ++n)
     {
         for(int k = 0; k < K; ++k)
         {
             int n0 = n / (NRepeat * NLane);
             int k0 = k / (KRepeat * KLane * KPack);
-            int nRel       = n % (NRepeat * NLane);
-            int kRel       = k % (KRepeat * KLane * KPack);
-            int nIndex     = nRel / NLane;
-            int kIndex     = kRel / (KLane * KPack);
-            int nLaneIndex = nRel % NLane;
-            int kLaneIndex = (kRel % (KLane * KPack)) / KPack;
-            int kPackIndex = kRel % KPack;
-
-            int outputIndex = (n0 * K0 + k0) * KRepeat * NRepeat * KLane * NLane * KPack +
-                              nIndex * KRepeat * KLane * KPack +
-                              kIndex * KLane * KPack +
-                              nLaneIndex * KPack +
-                              kLaneIndex * KPack +
-                              kPackIndex;
-
-            output[outputIndex] = buffer[n * K + k];
+            tempn  = n % (NRepeat * NLane);
+            tempk  = k % (KRepeat * KLane * KPack);
+            int n1 = tempn / NLane;
+            int k1 = tempk / (KLane * KPack);
+            int n2 = n1 % NLane;
+            tempk  = tempk % (KLane * KPack);
+            int k2 = tempk / KPack;
+            int k3 = tempk % KPack;
+
+            int outputIndex = n0 * KPack * NLane * KLane * KRepeat * NRepeat * K0 +
+                              k0 * KPack * NLane * KLane * KRepeat * NRepeat +
+                              n1 * KPack * NLane * KLane * KRepeat +
+                              k1 * KPack * NLane * KLane +
+                              k2 * KPack * NLane +
+                              n2 * KPack +
+                              k3;
+
+            dst[outputIndex] = src[n * K + k];
         }
     }
 }
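Reviewer note: the mapping above can be sanity-checked on the host. Below is a minimal, self-contained sketch (not part of the commit) that replays the index arithmetic with the new constants and asserts that every destination slot is written exactly once. One assumption is called out inline: with NRepeat == 1, n1 is always 0, so the committed "int n2 = n1 % NLane;" collapses the NLane coordinate; the sketch uses tempn % NLane instead, which makes the mapping a permutation.

// Standalone sketch: verify the preshuffle index mapping is a permutation.
#include <cassert>
#include <vector>

int main()
{
    const int NRepeat = 1, KRepeat = 4, KLane = 2, NLane = 128, KPack = 16;
    const int N  = 2 * NRepeat * NLane;         // two tiles along n
    const int K  = 2 * KRepeat * KLane * KPack; // two tiles along k
    const int K0 = K / (KRepeat * KLane * KPack);

    std::vector<int> hits(static_cast<size_t>(N) * K, 0);
    for(int n = 0; n < N; ++n)
        for(int k = 0; k < K; ++k)
        {
            int n0    = n / (NRepeat * NLane);
            int k0    = k / (KRepeat * KLane * KPack);
            int tempn = n % (NRepeat * NLane);
            int tempk = k % (KRepeat * KLane * KPack);
            int n1    = tempn / NLane;
            int k1    = tempk / (KLane * KPack);
            int n2    = tempn % NLane; // the diff writes n1 % NLane here
            tempk     = tempk % (KLane * KPack);
            int k2    = tempk / KPack;
            int k3    = tempk % KPack;

            int idx = n0 * KPack * NLane * KLane * KRepeat * NRepeat * K0 +
                      k0 * KPack * NLane * KLane * KRepeat * NRepeat +
                      n1 * KPack * NLane * KLane * KRepeat +
                      k1 * KPack * NLane * KLane +
                      k2 * KPack * NLane +
                      n2 * KPack + k3;
            ++hits[idx];
        }

    for(int h : hits)
        assert(h == 1); // a permutation: each slot written exactly once
    return 0;
}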
@@ -191,6 +193,7 @@ int main(int argc, char* argv[])
     Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{}));
     Tensor<B0DataType> b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{}));
+    Tensor<B0DataType> b0_preshuffled(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use layout only for size
     Tensor<D0DataType> d0_m_n(f_host_tensor_descriptor(M, N, StrideD, D0Layout{}));
     Tensor<D1DataType> d1_m_n(f_host_tensor_descriptor(M, N, StrideD, D1Layout{}));
     Tensor<EDataType> e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{}));
@@ -217,15 +220,15 @@ int main(int argc, char* argv[])
         d0_m_n.GenerateTensorValue(GeneratorTensor_3<D0DataType>{-0.5, 0.5});
         d1_m_n.GenerateTensorValue(GeneratorTensor_3<D1DataType>{-0.5, 0.5});
     }

     DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize());
     DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize());
     DeviceMem d0_device_buf(sizeof(D0DataType) * d0_m_n.mDesc.GetElementSpaceSize());
     DeviceMem d1_device_buf(sizeof(D1DataType) * d1_m_n.mDesc.GetElementSpaceSize());
     DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize());

+    preShuffleBuffer(b0_k_n.mData.data(), N, K, b0_preshuffled.mData.data());
     a0_device_buf.ToDevice(a0_m_k.mData.data());
-    b0_device_buf.ToDevice(b0_k_n.mData.data());
+    b0_device_buf.ToDevice(b0_preshuffled.mData.data());
     d0_device_buf.ToDevice(d0_m_n.mData.data());
     d1_device_buf.ToDevice(d1_m_n.mData.data());
     e_device_buf.ToDevice(e_m_n_device_result.mData.data());
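A note on the call: preShuffleBuffer takes (src, N, K, dst) even though b0_k_n is declared with (K, N), so the dimension order differs between the tensor descriptor and the shuffle call. A hypothetical, self-contained staging helper illustrating the same flow (FP8 is aliased to uint8_t here so the sketch compiles on its own; preShuffleBuffer is only declared, since its definition lives in the example):

// Sketch of the host-side staging flow added above; names are illustrative.
#include <algorithm>
#include <cstdint>
#include <vector>

using FP8 = std::uint8_t; // stand-in for the example's fp8 type
void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst); // from the diff

void stage_b_operand(const FP8* b0_k_n_data, int N, int K, FP8* upload_staging)
{
    // The preshuffled operand has no meaningful (K, N) strides; any flat
    // buffer of N * K elements works, which is why the example reuses the
    // b0_k_n descriptor "only for size".
    std::vector<FP8> b0_preshuffled(static_cast<size_t>(N) * K);
    preShuffleBuffer(b0_k_n_data, N, K, b0_preshuffled.data());
    std::copy(b0_preshuffled.begin(), b0_preshuffled.end(), upload_staging);
    // b0_device_buf.ToDevice(upload_staging) would follow in the example.
}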
include/ck/library/utility/host_tensor_generator.hpp

@@ -131,7 +131,7 @@ struct GeneratorTensor_2<ck::f8_t>
     template <typename... Is>
     ck::f8_t operator()(Is...)
     {
-        float tmp = (std::rand() % (max_value - min_value)) + min_value;
+        float tmp = 1;
         return ck::type_convert<ck::f8_t>(tmp);
     }
 };
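This one-character change pins every value this generator produces to 1, which makes the fp8 GEMM output predictable while the shuffle is being debugged (with an all-ones operand, each output element reduces to a plain sum over k). A toggleable stand-alone variant of the idea, a sketch rather than the library's API:

// Sketch: a fill functor that keeps the original random path available.
// GeneratorTensor_2's real definition lives in host_tensor_generator.hpp;
// only its call shape is imitated here.
#include <cstdlib>

struct GeneratorF8Debug
{
    int min_value       = 0;
    int max_value       = 2;
    bool debug_constant = true; // the commit hard-wires the equivalent of true

    template <typename... Is>
    float operator()(Is...) const
    {
        if(debug_constant)
            return 1.0f; // deterministic values make results easy to eyeball
        return static_cast<float>(std::rand() % (max_value - min_value) + min_value);
    }
};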
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp

@@ -281,7 +281,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
-                   ABlockBuffer& a_block_buf,
+                   ABlockBuffer& a_block_buf0,
+                   ABlockBuffer& a_block_buf1,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    const BBlockDesc& b_block_desc,

@@ -306,7 +307,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

         // // Local prefill 1
-        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf0);

         // // Global prefetch 2
         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

@@ -321,19 +322,11 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                    make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
-                                   a_block_buf,
+                                   a_block_buf0,
                                    a_thread_desc_,
                                    make_tuple(m0, I0, k0, I0),
                                    a_thread_buf);
             });
-            // static_for<0, NRepeat, 1>{}([&](auto n0) {
-            //     b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
-            //                        make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
-            //                        b_block_buf,
-            //                        b_thread_desc_,
-            //                        make_tuple(n0, I0, k0, I0),
-            //                        b_thread_buf);
-            // });
         });

         __builtin_amdgcn_sched_barrier(0);

@@ -344,9 +337,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
         index_t i = 0;
         do
         {
-            block_sync_lds();
-
-            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf1);
             a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
             b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<1>{});

@@ -364,8 +355,15 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                         a_thread_vec.template AsType<ComputeDataType>()(ik) =
                             a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                 make_tuple(m0, I0, k0, ik))>{}];
+                        // if(threadIdx.x==0) {
+                        //     printf("%f, %f; ", type_convert<float>(a_thread_vec.template AsType<ComputeDataType>()(ik)),
+                        //            type_convert<float>(b_thread_vec.template AsType<ComputeDataType>()(ik)));
+                        // }
                     });
+                    // if(threadIdx.x==0) {
+                    //     printf("\n");
+                    // }

                     using mfma_input_type =
                         typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

@@ -387,7 +385,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                    make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
-                                   a_block_buf,
+                                   a_block_buf1,
                                    a_thread_desc_,
                                    make_tuple(m0, I0, k0, I0),
                                    a_thread_buf);

@@ -397,10 +395,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
             HotLoopScheduler();
             __builtin_amdgcn_sched_barrier(0);

-            block_sync_lds();
-
-            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
+            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf0);
             a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
             b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, Number<0>{});

@@ -441,7 +436,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                    make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
-                                   a_block_buf,
+                                   a_block_buf0,
                                    a_thread_desc_,
                                    make_tuple(m0, I0, k0, I0),
                                    a_thread_buf);
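Taken together, these hunks convert the A-operand pipeline from one LDS tile guarded by block_sync_lds() to two tiles used in ping-pong fashion: one iteration writes buf1 while compute reads buf0, the next does the opposite. A host-side sketch of that schedule (stage count and tile size invented for illustration):

// Sketch of the ping-pong schedule now used for the A operand's LDS tiles.
#include <array>
#include <cstdio>
#include <vector>

int main()
{
    constexpr int stages = 6;
    std::array<std::vector<float>, 2> lds{std::vector<float>(256),
                                          std::vector<float>(256)};
    for(int i = 0; i < stages; ++i)
    {
        auto& write_buf = lds[i % 2];       // RunWrite target this iteration
        auto& read_buf  = lds[(i + 1) % 2]; // a_thread_copy_ source
        write_buf.assign(write_buf.size(), float(i)); // stands in for RunWrite
        float acc = 0;
        for(float v : read_buf)
            acc += v;                       // stands in for the MFMA consume
        std::printf("stage %d consumed buffer %d (acc=%g)\n", i, (i + 1) % 2, acc);
    }
    return 0;
}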
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp

@@ -486,52 +486,52 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
             // Tail number could be Odd or Even
             else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
             {
-                if(arg.KBatch > 1)
-                {
-                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
-                    {
-                        const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
-                            GridwiseGemm,
-                            true,
-                            InMemoryDataOperationEnum::AtomicAdd,
-                            minimum_occupancy,
-                            TailNumber::Odd>;
-                        Run(kernel);
-                    }
-                    else
-                    {
-                        const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
-                            GridwiseGemm,
-                            true,
-                            InMemoryDataOperationEnum::AtomicAdd,
-                            minimum_occupancy,
-                            TailNumber::Even>;
-                        Run(kernel);
-                    }
-                }
-                else
-                {
-                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
-                    {
-                        const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
-                            GridwiseGemm,
-                            true,
-                            InMemoryDataOperationEnum::Set,
-                            minimum_occupancy,
-                            TailNumber::Odd>;
-                        Run(kernel);
-                    }
-                    else
-                    {
-                        const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
-                            GridwiseGemm,
-                            true,
-                            InMemoryDataOperationEnum::Set,
-                            minimum_occupancy,
-                            TailNumber::Even>;
-                        Run(kernel);
-                    }
-                }
+                // if(arg.KBatch > 1)
+                // {
+                //     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                //     {
+                //         const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
+                //             GridwiseGemm,
+                //             true,
+                //             InMemoryDataOperationEnum::AtomicAdd,
+                //             minimum_occupancy,
+                //             TailNumber::Odd>;
+                //         Run(kernel);
+                //     }
+                //     else
+                //     {
+                //         const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
+                //             GridwiseGemm,
+                //             true,
+                //             InMemoryDataOperationEnum::AtomicAdd,
+                //             minimum_occupancy,
+                //             TailNumber::Even>;
+                //         Run(kernel);
+                //     }
+                // }
+                // else
+                // {
+                //     if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                //     {
+                //         const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
+                //             GridwiseGemm,
+                //             true,
+                //             InMemoryDataOperationEnum::Set,
+                //             minimum_occupancy,
+                //             TailNumber::Odd>;
+                //         Run(kernel);
+                //     }
+                //     else
+                //     {
+                //         const auto kernel = kernel_gemm_xdl_cshuffle_v3_multi_d_2lds<
+                //             GridwiseGemm,
+                //             true,
+                //             InMemoryDataOperationEnum::Set,
+                //             minimum_occupancy,
+                //             TailNumber::Even>;
+                //         Run(kernel);
+                //     }
+                // }
             }
             else
             {
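Since kernel_gemm_xdl_cshuffle_v3_multi_d_2lds is deleted from the gridwise header in the next file, the v4 path is parked behind comments here instead of being rewritten. For reference, the dispatch shape that block implements, reduced to a self-contained sketch with stub kernels (TailNumber and kernel_stub are stand-ins, not CK's types):

// Sketch only: runtime KBatch / tail information selects one of four
// template instantiations; split-K partials need atomic accumulation.
#include <cstdio>

enum class TailNumber { Odd, Even };

template <bool UseAtomicAdd, TailNumber Tail>
void kernel_stub()
{
    std::printf("atomic_add=%d tail=%s\n", UseAtomicAdd,
                Tail == TailNumber::Odd ? "odd" : "even");
}

void dispatch_v4(int k_batch, TailNumber tail)
{
    if(k_batch > 1) // split-K partials must be accumulated atomically
    {
        if(tail == TailNumber::Odd) kernel_stub<true, TailNumber::Odd>();
        else                        kernel_stub<true, TailNumber::Even>();
    }
    else            // a single batch can overwrite the output directly
    {
        if(tail == TailNumber::Odd) kernel_stub<false, TailNumber::Odd>();
        else                        kernel_stub<false, TailNumber::Even>();
    }
}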
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp

@@ -40,6 +40,7 @@ __global__ void
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+    __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

@@ -49,42 +50,7 @@ __global__ void
         karg.p_ds_grid,
         karg.p_c_grid,
         p_shared,
+        p_shared1,
         karg,
         karg.a_element_op,
         karg.b_element_op,
         karg.c_element_op);
 #else
     ignore = karg;
 #endif // end of if (defined(__gfx9__))
 }

-template <typename GridwiseGemm,
-          bool HasMainKBlockLoop,
-          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
-          index_t MinimumOccupancy = 1,
-          TailNumber TailNum       = TailNumber::Full>
-__global__ void
-#if CK_USE_LAUNCH_BOUNDS
-    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
-#endif
-    // __attribute__((amdgpu_waves_per_eu(1, 1)))
-    kernel_gemm_xdl_cshuffle_v3_multi_d_2lds(typename GridwiseGemm::Argument karg)
-{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
-    // Pass two lds pointer is the key to tell compiler that ds_read/write
-    // operate on different lds chunk at same time without order dependecy
-    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
-
-    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
-
-    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
-        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
-        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
-        karg.p_ds_grid,
-        karg.p_c_grid,
-        p_shared_0,
-        p_shared_1,
-        karg,
-        karg.a_element_op,
-        karg.b_element_op,

@@ -1256,6 +1222,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         DsGridPointer& p_ds_grid,
         CDataType* p_c_grid,
         void* p_shared,
+        void* p_shared1,
         const Problem& problem,
         AElementwiseOperation a_element_op,
         BElementwiseOperation b_element_op,

@@ -1268,6 +1235,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
             p_ds_grid,
             p_c_grid,
             p_shared,
+            p_shared1,
             problem,
             a_element_op,
             b_element_op,

@@ -1284,6 +1252,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         DsGridPointer& p_ds_grid,
         CDataType* p_c_grid,
         void* p_shared,
+        void* p_shared1,
         const Problem& problem,
         AElementwiseOperation a_element_op,
         BElementwiseOperation b_element_op,

@@ -1409,6 +1378,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         // Cast after lds
         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+        auto a_block_buf1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
+            static_cast<LDSTypeA*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<LDSTypeB*>(p_shared) +

@@ -1432,6 +1403,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
             a_blockwise_copy,
             a_grid_buf,
             a_block_buf,
+            a_block_buf1,
             a_block_slice_copy_step,
             b_grid_desc_bpreshuffled,
             b_block_desc_bk0_n_bk1,
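The removed *_2lds kernel's comment explains the motivation: handing the gridwise GEMM two distinct LDS pointers lets the compiler overlap ds_read/ds_write on different chunks without an ordering dependency. After this commit, the plain multi-d kernel does the same thing itself. A trimmed HIP-style sketch of the resulting prologue (the template is never instantiated here, and the Run call is elided):

// Sketch of the kernel prologue after the change: one kernel, two LDS tiles.
// GridwiseGemm stands for any gridwise GEMM exposing
// GetSharedMemoryNumberOfByte() and a Run(..., p_shared, p_shared1, ...).
template <typename GridwiseGemm>
__global__ void kernel_multi_d_two_lds_sketch(typename GridwiseGemm::Argument karg)
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

    // GridwiseGemm::template Run<...>(..., p_shared, p_shared1, karg, ...);
    (void)splitk_batch_offset;
}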