Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
54f44e62
Commit
54f44e62
authored
Dec 30, 2024
by
coderfeli
Browse files
fix brepeat, kloop and lds two buffer; works ok now
parent
2c056624
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
150 additions
and
37 deletions
+150
-37
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp
...gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp
+141
-18
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
+1
-12
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
+8
-7
No files found.
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp8.cpp
View file @
54f44e62
...
@@ -24,6 +24,11 @@
...
@@ -24,6 +24,11 @@
template
<
ck
::
index_t
...
Is
>
template
<
ck
::
index_t
...
Is
>
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
S
=
ck
::
Sequence
<
Is
...
>
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
// using I8 = int8_t;
// using I32 = int;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
FP8
=
ck
::
f8_t
;
using
FP8
=
ck
::
f8_t
;
using
F32
=
float
;
using
F32
=
float
;
...
@@ -54,25 +59,139 @@ struct MultiplyMultiply
...
@@ -54,25 +59,139 @@ struct MultiplyMultiply
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
)
const
;
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
,
const
D1
&
d1
)
const
;
template
<
>
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
half_t
,
float
,
float
,
float
>
(
__host__
__device__
constexpr
void
operator
()
<
F16
,
float
,
float
,
float
>
(
ck
::
half_t
&
e
,
const
float
&
c
,
const
float
&
d0
,
const
float
&
d1
)
const
F16
&
e
,
const
float
&
c
,
const
float
&
d0
,
const
float
&
d1
)
const
{
{
const
float
x0_f
=
c
*
d0
*
d1
;
const
float
x0_f
=
c
*
d0
*
d1
;
e
=
ck
::
type_convert
<
F16
>
(
x0_f
);
}
template
<
>
__host__
__device__
constexpr
void
operator
()
<
ck
::
half_t
,
int
,
float
,
float
>
(
ck
::
half_t
&
e
,
const
int
&
c
,
const
float
&
d0
,
const
float
&
d1
)
const
{
const
float
x0_f
=
ck
::
type_convert
<
float
>
(
c
)
*
ck
::
type_convert
<
float
>
(
d0
)
*
ck
::
type_convert
<
float
>
(
d1
);
e
=
ck
::
type_convert
<
ck
::
half_t
>
(
x0_f
);
e
=
ck
::
type_convert
<
ck
::
half_t
>
(
x0_f
);
}
}
};
};
// struct MultiplyMultiply
// {
// template <typename E, typename C, typename D0, typename D1>
// __host__ __device__ constexpr void
// operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
// template <>
// __host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
// ck::half_t& e, const float& c, const float& d0, const float& d1) const
// {
// const float x0_f = c * d0 * d1;
// e = ck::type_convert<ck::half_t>(x0_f);
// }
// template <>
// __host__ __device__ constexpr void operator()<ck::half_t, int, float, float>(
// ck::half_t& e, const int& c, const float& d0, const float& d1) const
// {
// const float x0_f =
// ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
// e = ck::type_convert<ck::half_t>(x0_f);
// }
// template <>
// __host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
// ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
// {
// const float x0_f =
// ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
// e = ck::type_convert<ck::bhalf_t>(x0_f);
// }
// };
// void reinit2(FP8* dst, int N, int K) {
// for (int n = 0; n < N; ++n) {
// int kinit = 0;
// for (int k = 0; k < K; k+=1) {
// // dst[n * K + k] = n;
// if(k>0 && k%128==0){
// kinit += 1;
// }
// dst[n * K + k] = k % 128 + kinit;//rand() % 5 - 2;
// }
// }
// }
// void reinit(FP8* dst, int N, int K) {
// for (int n = 0; n < N; ++n) {
// for (int k = 0; k < K; k+=1) {
// dst[n * K + k] = ck::type_convert<FP8>(float(1));
// }
// }
// }
void
dump
(
FP8
*
dst
,
int
N
,
int
K
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
printf
(
"%.1f,"
,
ck
::
type_convert
<
float
>
(
dst
[
n
*
K
+
k
]));
}
printf
(
"
\n
"
);
}
}
// void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
// const int NRepeat = 1;
// const int KRepeat = 8;
// const int NWave = 4;
// const int KLane = 2;
// const int NLane = 32;
// const int KPack = 16;
// int K0 = K / (KRepeat * KLane * KPack);
// int tempn, tempk;
// for (int n = 0; n < N; ++n) {
// for (int k = 0; k < K; ++k) {
// int n0 = n / (NRepeat * NLane * NWave);
// int k0 = k / (KRepeat * KLane * KPack);
// tempn = n % (NRepeat * NLane * NWave);
// tempk = k % (KRepeat * KLane * KPack);
// int n1 = tempn / (NLane * NWave);
// int k1 = tempk / (KLane * KPack);
// tempn = tempn % (NLane * NWave);
// tempk = tempk % (KLane * KPack);
// int n2 = tempn / NLane;
// int k2 = tempk / KPack;
// int n3 = tempn % NLane;
// int k3 = tempk % KPack;
// int outputIndex = n0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat * K0
// + k0 * KPack * NLane * KLane * NWave * KRepeat * NRepeat
// + n1 * KPack * NLane * KLane * NWave * KRepeat
// + k1 * KPack * NLane * KLane * NWave
// + n2 * KPack * NLane * KLane
// + k2 * KPack * NLane
// + n3 * KPack
// + k3;
// dst[outputIndex] = src[n * K + k];
// }
// }
// }
void
preShuffleBuffer
(
const
FP8
*
src
,
int
N
,
int
K
,
FP8
*
dst
)
{
void
preShuffleBuffer
(
const
FP8
*
src
,
int
N
,
int
K
,
FP8
*
dst
)
{
const
int
NRepeat
=
1
;
const
int
NRepeat
=
1
;
const
int
KRepeat
=
4
;
const
int
KRepeat
=
8
;
const
int
NWave
=
4
;
const
int
NWave
=
4
;
const
int
KLane
=
2
;
const
int
KLane
=
2
;
const
int
NLane
=
32
;
const
int
NLane
=
32
;
const
int
KPack
=
16
;
const
int
KPack
=
16
;
int
N0
=
N
/
(
NRepeat
*
NLane
*
NWave
);
int
K0
=
K
/
(
KRepeat
*
KLane
*
KPack
);
int
K0
=
K
/
(
KRepeat
*
KLane
*
KPack
);
// K -> src: K0 KLane KRepeat KPack -> dst: K0 KRepeat KLane KPack; KLane is moved inward so that all lanes are contiguous
// N -> N0 NRepeat NWave NLane // TODO: is NRepeat the outer or the inner dimension? It is currently 1.
int
tempn
,
tempk
;
int
tempn
,
tempk
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
...
@@ -80,21 +199,22 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
...
@@ -80,21 +199,22 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
int
k0
=
k
/
(
KRepeat
*
KLane
*
KPack
);
int
k0
=
k
/
(
KRepeat
*
KLane
*
KPack
);
tempn
=
n
%
(
NRepeat
*
NLane
*
NWave
);
tempn
=
n
%
(
NRepeat
*
NLane
*
NWave
);
tempk
=
k
%
(
KRepeat
*
KLane
*
KPack
);
tempk
=
k
%
(
KRepeat
*
KLane
*
KPack
);
int
n1
=
tempn
/
(
NLane
*
NWave
);
int
n1
=
tempn
/
(
NLane
*
NWave
);
int
k1
=
tempk
/
(
K
Lane
*
KPack
);
int
k1
=
tempk
/
(
K
Repeat
*
KPack
);
// Klane
tempn
=
tempn
%
(
NLane
*
NWave
);
tempn
=
tempn
%
(
NLane
*
NWave
);
tempk
=
tempk
%
(
K
Lane
*
KPack
);
tempk
=
tempk
%
(
K
Repeat
*
KPack
);
int
n2
=
tempn
/
NLane
;
int
n2
=
tempn
/
NLane
;
int
k2
=
tempk
/
KPack
;
int
k2
=
tempk
/
KPack
;
// KRepeat
int
n3
=
tempn
%
NLane
;
int
n3
=
tempn
%
NLane
;
int
k3
=
tempk
%
KPack
;
int
k3
=
tempk
%
KPack
;
// Kpack
int
outputIndex
=
n0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
NRepeat
*
K0
int
outputIndex
=
n0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
NRepeat
*
K0
+
k0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
NRepeat
+
k0
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
*
NRepeat
+
n1
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
+
n1
*
KPack
*
NLane
*
KLane
*
NWave
*
KRepeat
+
k
1
*
KPack
*
NLane
*
KLane
*
NWave
+
k
2
*
KPack
*
NLane
*
KLane
*
NWave
//switch k1, k2
+
n2
*
KPack
*
NLane
*
KLane
+
n2
*
KPack
*
NLane
*
KLane
+
k
2
*
KPack
*
NLane
+
k
1
*
KPack
*
NLane
+
n3
*
KPack
+
n3
*
KPack
+
k3
;
+
k3
;
...
@@ -102,7 +222,6 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
...
@@ -102,7 +222,6 @@ void preShuffleBuffer(const FP8* src, int N, int K, FP8* dst) {
}
}
}
}
}
}
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
...
@@ -120,6 +239,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
...
@@ -120,6 +239,7 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShu
///###### RCR
///###### RCR
// kernel 1: 256->32x128x128
// kernel 1: 256->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 128, 16, 16, 32, 32, 1, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 32, 128, 256, 16, 16, 32, 32, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>;
<
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
256
,
32
,
128
,
256
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
,
FP8
>
;
<
Row
,
Col
,
DsLayout
,
ELayout
,
A0DataType
,
B0DataType
,
DsDataType
,
EDataType
,
AccDataType
,
CShuffleDataType
,
AElementOp
,
BElementOp
,
CDEElementOp
,
GemmSpec
,
256
,
32
,
128
,
256
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
ck
::
BlockGemmPipelineScheduler
::
Intrawave
,
ck
::
BlockGemmPipelineVersion
::
v3
,
FP8
>
;
// kernel 2: 128->32x128x128
// kernel 2: 128->32x128x128
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
// < Row, Col, DsLayout, ELayout, A0DataType, B0DataType, DsDataType, EDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 128, 32, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, FP8>;
...
@@ -215,8 +335,8 @@ int main(int argc, char* argv[])
...
@@ -215,8 +335,8 @@ int main(int argc, char* argv[])
case
1
:
case
1
:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
2
,
2
});
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
2
,
2
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
0
,
2
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
0
,
2
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
0
,
2
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
2
,
2
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
0
,
2
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
2
,
2
});
break
;
break
;
default:
default:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
...
@@ -229,9 +349,12 @@ int main(int argc, char* argv[])
...
@@ -229,9 +349,12 @@ int main(int argc, char* argv[])
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d0_device_buf
(
sizeof
(
D0DataType
)
*
d0_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
d1_device_buf
(
sizeof
(
D1DataType
)
*
d1_m_n
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
e_device_buf
(
sizeof
(
EDataType
)
*
e_m_n_device_result
.
mDesc
.
GetElementSpaceSize
());
// reinit2(a0_m_k.mData.data(), M, K);
// reinit2(b0_k_n.mData.data(), N, K);
preShuffleBuffer
(
b0_k_n
.
mData
.
data
(),
N
,
K
,
b0_preshuffled
.
mData
.
data
());
preShuffleBuffer
(
b0_k_n
.
mData
.
data
(),
N
,
K
,
b0_preshuffled
.
mData
.
data
());
// dump(b0_preshuffled.mData.data(), N, K);
a0_device_buf
.
ToDevice
(
a0_m_k
.
mData
.
data
());
a0_device_buf
.
ToDevice
(
a0_m_k
.
mData
.
data
());
// b0_device_buf.ToDevice(b0_preshuffled.mData.data());
b0_device_buf
.
ToDevice
(
b0_preshuffled
.
mData
.
data
());
b0_device_buf
.
ToDevice
(
b0_preshuffled
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d0_device_buf
.
ToDevice
(
d0_m_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
d1_device_buf
.
ToDevice
(
d1_m_n
.
mData
.
data
());
...
@@ -273,7 +396,7 @@ int main(int argc, char* argv[])
...
@@ -273,7 +396,7 @@ int main(int argc, char* argv[])
"not support this GEMM problem"
);
"not support this GEMM problem"
);
}
}
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
,
0
,
1
});
float
ave_time
=
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
time_kernel
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
M
*
N
*
K
;
std
::
size_t
num_btype
=
std
::
size_t
num_btype
=
...
@@ -288,7 +411,7 @@ int main(int argc, char* argv[])
...
@@ -288,7 +411,7 @@ int main(int argc, char* argv[])
if
(
do_verification
)
if
(
do_verification
)
{
{
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
});
invoker
.
Run
(
argument
,
StreamConfig
{
nullptr
,
false
,
0
,
1
,
1
});
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
e_device_buf
.
FromDevice
(
e_m_n_device_result
.
mData
.
data
());
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
View file @
54f44e62
...
@@ -328,7 +328,6 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -328,7 +328,6 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_thread_buf
);
a_thread_buf
);
});
});
});
});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
// main body
// main body
...
@@ -355,15 +354,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -355,15 +354,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_vec
.
template
AsType
<
ComputeDataType
>()(
ik
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
make_tuple
(
m0
,
I0
,
k0
,
ik
))
>
{}];
// if(threadIdx.x==0) {
// printf("%f, %f; ", type_convert<float>(a_thread_vec.template AsType<ComputeDataType>()(ik)), type_convert<float>(b_thread_vec.template AsType<ComputeDataType>()(ik)));
// }
});
});
// if(threadIdx.x==0) {
// printf("\n");
// }
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
...
@@ -442,20 +434,17 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
...
@@ -442,20 +434,17 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_thread_buf
);
a_thread_buf
);
});
});
});
});
HotLoopScheduler
();
HotLoopScheduler
();
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
i
+=
2
;
i
+=
2
;
}
while
(
i
<
(
num_loop
-
1
));
}
while
(
i
<
(
num_loop
-
2
));
}
}
// tail
// tail
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
{
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf1
);
a_blockwise_copy
.
RunWrite
(
a_block_desc
,
a_block_buf1
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
Number
<
1
>
{});
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
,
Number
<
1
>
{});
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
KRepeat
,
1
>
{}([
&
](
auto
k0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
View file @
54f44e62
...
@@ -130,10 +130,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -130,10 +130,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
AK1Number
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
static
constexpr
auto
BK1Number
=
Number
<
BK1Value
>
{};
static
constexpr
auto
BlockSizeNumber
=
Number
<
BlockSize
>
{};
static
constexpr
auto
BlockSizeNumber
=
Number
<
BlockSize
>
{};
static
constexpr
index_t
NLane
=
128
;
static
constexpr
index_t
NLane
=
32
;
static
constexpr
index_t
NWave
=
4
;
static
constexpr
index_t
KLane
=
2
;
static
constexpr
index_t
KLane
=
2
;
static
constexpr
index_t
KRepeat
=
4
;
static
constexpr
index_t
KRepeat
=
8
;
static_assert
(
NLane
*
KLane
==
BlockSize
);
static_assert
(
NLane
*
NWave
*
KLane
==
BlockSize
);
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
...
@@ -173,11 +174,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -173,11 +174,11 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
__host__
__device__
static
auto
CalculateBN0Shuffled
(
index_t
N
)
__host__
__device__
static
auto
CalculateBN0Shuffled
(
index_t
N
)
{
{
return
math
::
integer_
least_multiple
(
N
,
NLan
e
);
return
math
::
integer_
divide_ceil
(
N
,
NLane
*
NWav
e
);
}
}
__host__
__device__
static
auto
CalculateBK0Shuffled
(
index_t
K
,
index_t
KBatch
)
__host__
__device__
static
auto
CalculateBK0Shuffled
(
index_t
K
,
index_t
KBatch
)
{
{
return
math
::
integer_
least_multiple
(
K
,
KLane
*
KPack
*
KBatch
);
return
math
::
integer_
divide_ceil
(
K
,
KLane
*
KPack
*
KBatch
);
}
}
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
)
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
)
...
@@ -1296,8 +1297,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
...
@@ -1296,8 +1297,8 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
__builtin_amdgcn_readfirstlane
(
block_m_id
*
MPerBlock
);
__builtin_amdgcn_readfirstlane
(
block_m_id
*
MPerBlock
);
const
index_t
n_block_data_idx_on_grid
=
const
index_t
n_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_n_id
*
NPerBlock
)
/
NLane
;
__builtin_amdgcn_readfirstlane
(
block_n_id
*
(
NPerBlock
/
NLane
/
NWave
))
;
// lds max alignment
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1Number
,
BK1Number
);
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1Number
,
BK1Number
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment