gaoqiong / composable_kernel_ROCM · Commits

Commit 1a089f6f, authored Dec 26, 2024 by aska-0096
Parent: c8c016dd

    sanity bug fix

Showing 5 changed files with 232 additions and 188 deletions (+232 -188)
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp           +3  -2
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp  +20 -2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp                  +93 -80
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp          +96 -82
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp                       +20 -22
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp  (view file @ 1a089f6f)

@@ -54,8 +54,9 @@ struct BlockwiseGemmXdlops_pipeline_base
     static constexpr index_t AMmaKStride = KPack;
     static constexpr index_t BMmaKStride = KPack;

-    static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
-    static constexpr index_t KRepeat    = KPerThread / KPack;
+    static constexpr index_t KPerThread    = KPerBlock / xdlops_gemm.K0PerXdlops;
+    static constexpr index_t KRepeat       = KPerThread / KPack;
+    static constexpr index_t KPerInnerLoop = KPack;

     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
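Note: the example below is not part of the commit; it is a minimal standalone sketch with made-up tile sizes showing how the new KPerInnerLoop constant relates to the existing KPerThread and KRepeat constants. When KPerInnerLoop equals KPack, the static_for<0, KPerInnerLoop, KPack> loop added to the gridwise kernels further down executes exactly once, so the new k_ + ik thread-buffer offsets reduce to the previous ik offsets.

#include <cstdio>

// Hypothetical example values; real values come from the tuned instance.
constexpr int KPerBlock   = 64; // K extent of the block tile
constexpr int K0PerXdlops = 4;  // K splits handled per XDL instruction group
constexpr int KPack       = 8;  // elements packed per MFMA input vector

constexpr int KPerThread    = KPerBlock / K0PerXdlops; // 16
constexpr int KRepeat       = KPerThread / KPack;      // 2
constexpr int KPerInnerLoop = KPack;                   // 8

int main()
{
    // The inner loop introduced by this commit steps k_ = 0, KPack, ... while k_ < KPerInnerLoop.
    // With KPerInnerLoop == KPack it performs a single iteration (k_ == 0),
    // so the thread-buffer offset k_ + ik degenerates to ik.
    for(int k_ = 0; k_ < KPerInnerLoop; k_ += KPack)
        std::printf("k_ = %d\n", k_);

    std::printf("KPerThread=%d KRepeat=%d KPerInnerLoop=%d\n",
                KPerThread, KRepeat, KPerInnerLoop);
}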
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp  (view file @ 1a089f6f)

@@ -227,8 +227,26 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleDSplitK<ALayo
            }
        };

-        constexpr index_t minimum_occupancy =
-            BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
+        constexpr index_t minimum_occupancy = []() {
+            if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
+            {
+                return 2;
+            }
+            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
+            {
+                constexpr index_t instance_lds_size =
+                    MPerBlock * KPerBlock * sizeof(ADataType) +
+                    NPerBlock * KPerBlock * sizeof(BDataType);
+
+                return ((MPerBlock * NPerBlock / BlockSize <= 128) &&
+                        (instance_lds_size <= 32768))
+                           ? 2
+                           : 1;
+            }
+            else
+            {
+                return 1;
+            }
+        }();

        if(has_main_k_block_loop)
        {
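Note: the arithmetic below is not part of the commit; it is a small standalone sketch with hypothetical instance parameters (fp16 A/B, 128x128x64 tile, 256-thread block) evaluating the same expressions as the new minimum_occupancy lambda. For this example the LDS footprint is exactly 32768 bytes and the per-thread C tile is 64 elements, so the lambda would select an occupancy of 2.

#include <cstdio>

int main()
{
    // Hypothetical instance parameters, not taken from this commit.
    constexpr int BlockSize = 256;
    constexpr int MPerBlock = 128;
    constexpr int NPerBlock = 128;
    constexpr int KPerBlock = 64;
    constexpr int sizeofA   = 2; // sizeof a fp16 A element
    constexpr int sizeofB   = 2; // sizeof a fp16 B element

    // Same expressions as the new minimum_occupancy lambda shown above.
    constexpr int instance_lds_size =
        MPerBlock * KPerBlock * sizeofA + NPerBlock * KPerBlock * sizeofB; // 32768 bytes

    constexpr int c_per_thread = MPerBlock * NPerBlock / BlockSize; // 64 elements

    constexpr int minimum_occupancy =
        (c_per_thread <= 128 && instance_lds_size <= 32768) ? 2 : 1;

    std::printf("lds=%d B, c/thread=%d, occupancy=%d\n",
                instance_lds_size, c_per_thread, minimum_occupancy);
}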
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp  (view file @ 1a089f6f)

@@ -1550,30 +1550,35 @@ struct GridwiseGemm_xdl_cshuffle_v3
         static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
+        constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;

         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;

-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(n0, I0, k0, ik))>{}];
-                    });
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(n0, I0, k0, k_ + ik))>{}];
+                        });

-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;

-                    constexpr index_t c_offset =
-                        c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
+                        constexpr index_t c_offset =
+                            c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));

-                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                    b_thread_vec.template AsType<mfma_input_type>(),
-                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
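Note: the sketch below is not part of the commit; static_for_sketch is a simplified stand-in for ck::static_for, not the library implementation, and the loop bounds are made-up examples. It prints the (k0, k_, ik) combinations visited by the new loop nest; when KPerInnerLoop equals KPack, the k_ loop runs once and the k_ + ik offsets coincide with the ik offsets used before this change.

#include <cstdio>

// Simplified stand-in for ck::static_for<Begin, End, Inc>: calls f(i) for
// i = Begin, Begin + Inc, ... while i < End. The real ck::static_for passes a
// compile-time Number<i>; here a plain int is enough for illustration.
template <int Begin, int End, int Inc, typename F>
void static_for_sketch(F f)
{
    if constexpr(Begin < End)
    {
        f(Begin);
        static_for_sketch<Begin + Inc, End, Inc>(f);
    }
}

int main()
{
    // Hypothetical sizes for illustration only.
    constexpr int KRepeat       = 2;
    constexpr int KPerInnerLoop = 8;
    constexpr int KPack         = 8;

    static_for_sketch<0, KRepeat, 1>([](auto k0) {
        static_for_sketch<0, KPerInnerLoop, KPack>([k0](auto k_) {
            static_for_sketch<0, KPack, 1>([k0, k_](auto ik) {
                // Offset used to read a_thread_buf / b_thread_buf in the new code.
                std::printf("k0=%d k_=%d ik=%d -> k_ + ik = %d\n", k0, k_, ik, k_ + ik);
            });
        });
    });
}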
@@ -1592,29 +1597,31 @@ struct GridwiseGemm_xdl_cshuffle_v3
         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;

-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
-                    });
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
+                        });

-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;

-                    constexpr index_t c_offset = c_thread_desc.CalculateOffset(
-                        make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
+                        constexpr index_t c_offset = c_thread_desc.CalculateOffset(
+                            make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));

-                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                    b_thread_vec.template AsType<mfma_input_type>(),
-                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -2025,31 +2032,35 @@ struct GridwiseGemm_xdl_cshuffle_v3
                                          CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

         static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
+        constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;

         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;

-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(n0, I0, k0, ik))>{}];
-                    });
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(n0, I0, k0, k_ + ik))>{}];
+                        });

-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;

-                    constexpr index_t c_offset =
-                        c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
+                        constexpr index_t c_offset =
+                            c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));

-                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                    b_thread_vec.template AsType<mfma_input_type>(),
-                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -2068,29 +2079,31 @@ struct GridwiseGemm_xdl_cshuffle_v3
         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;

-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
-                    });
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
+                        });

-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;

-                    constexpr index_t c_offset = c_thread_desc.CalculateOffset(
-                        make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
+                        constexpr index_t c_offset = c_thread_desc.CalculateOffset(
+                            make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));

-                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                    b_thread_vec.template AsType<mfma_input_type>(),
-                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp  (view file @ 1a089f6f)

@@ -1685,30 +1685,35 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
+        constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;

         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;

-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(n0, I0, k0, ik))>{}];
-                    });
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(n0, I0, k0, k_ + ik))>{}];
+                        });

-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;

-                    constexpr index_t c_offset =
-                        c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
+                        constexpr index_t c_offset =
+                            c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));

-                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                    b_thread_vec.template AsType<mfma_input_type>(),
-                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -1728,29 +1733,31 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;

-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
-                    });
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
+                        });

-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;

-                    constexpr index_t c_offset = c_thread_desc.CalculateOffset(
-                        make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
+                        constexpr index_t c_offset = c_thread_desc.CalculateOffset(
+                            make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));

-                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                    b_thread_vec.template AsType<mfma_input_type>(),
-                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -1790,7 +1797,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                     I0,
                     I0,
                     cde_lds_and_global_step);
-                EpilogueScheduler();
+                // EpilogueScheduler();
             }
         });
     }
@@ -2236,30 +2243,35 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
+        constexpr auto KPerInnerLoop = blockwise_gemm_pipeline.KPerInnerLoop;

         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;

-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(n0, I0, k0, ik))>{}];
-                    });
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(n0, I0, k0, k_ + ik))>{}];
+                        });

-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;

-                    constexpr index_t c_offset =
-                        c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));
+                        constexpr index_t c_offset =
+                            c_thread_desc.CalculateOffset(make_tuple(m0, n0, 0));

-                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                    b_thread_vec.template AsType<mfma_input_type>(),
-                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -2279,29 +2291,31 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
         static_for<0, CShuffleMXdlPerWavePerShuffle, 1>{}([&](auto m0) {
             static_for<0, CShuffleNXdlPerWavePerShuffle, 1>{}([&](auto n0) {
                 static_for<0, KRepeat, 1>{}([&](auto k0) {
-                    vector_type<ComputeTypeA, KPack> a_thread_vec;
-                    vector_type<ComputeTypeB, KPack> b_thread_vec;
+                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
+                        vector_type<ComputeTypeA, KPack> a_thread_vec;
+                        vector_type<ComputeTypeB, KPack> b_thread_vec;

-                    static_for<0, KPack, 1>{}([&](auto ik) {
-                        a_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            a_thread_buf[Number<a_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_m0 + m0, I0, k0, ik))>{}];
-                        b_thread_vec.template AsType<ComputeTypeA>()(ik) =
-                            b_thread_buf[Number<b_thread_desc.CalculateOffset(
-                                make_tuple(shuffle_n0 + n0, I0, k0, ik))>{}];
-                    });
+                        static_for<0, KPack, 1>{}([&](auto ik) {
+                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                a_thread_buf[Number<a_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_m0 + m0, I0, k0, k_ + ik))>{}];
+                            b_thread_vec.template AsType<ComputeTypeA>()(ik) =
+                                b_thread_buf[Number<b_thread_desc.CalculateOffset(
+                                    make_tuple(shuffle_n0 + n0, I0, k0, k_ + ik))>{}];
+                        });

-                    using mfma_input_type =
-                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
+                        using mfma_input_type =
+                            typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;

-                    constexpr index_t c_offset = c_thread_desc.CalculateOffset(
-                        make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));
+                        constexpr index_t c_offset = c_thread_desc.CalculateOffset(
+                            make_tuple(shuffle_m0 + m0, shuffle_n0 + n0, 0));

-                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
-                                    b_thread_vec.template AsType<mfma_input_type>(),
-                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        xdlops_gemm.Run(
+                            a_thread_vec.template AsType<mfma_input_type>(),
+                            b_thread_vec.template AsType<mfma_input_type>(),
+                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                    });
                 });
             });
         });
@@ -2341,7 +2355,7 @@ struct GridwiseGemmMultiD_xdl_cshuffle_v3
                     I0,
                     I0,
                     cde_lds_and_global_step);
-                EpilogueScheduler();
+                // EpilogueScheduler();
             }
         });
     }
profiler/include/profiler/profile_gemm_multiply_multiply_impl.hpp  (view file @ 1a089f6f)

@@ -234,7 +234,26 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
        {
            c_device_buf.FromDevice(e_m_n_device_result.mData.data());

-            pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
+#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
+            // set softer tolerances for fp8
+            if constexpr((is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
+                          is_same_v<EDataType, f8_t>) ||
+                         (is_same_v<ADataType, int8_t> || is_same_v<BDataType, int8_t> ||
+                          is_same_v<EDataType, int8_t>))
+            {
+                std::string msg = "Error: Incorrect results!";
+                double rtol     = 1e-1;
+                double atol     = 1e-1;
+                pass            = pass & ck::utils::check_err(
+                                      e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
+            }
+            else
+            {
+#endif
+                pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
+#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
+            }
+#endif

            if(do_log)
            {
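Note: the helper below is not part of the commit and is not ck::utils::check_err; it is a hypothetical sketch of the kind of mixed relative/absolute tolerance test the new rtol/atol arguments request, assuming an element passes when |out - ref| <= atol + rtol * |ref|.

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical helper illustrating a combined relative/absolute tolerance check;
// an element passes if |out - ref| <= atol + rtol * |ref|.
bool check_close(const std::vector<float>& out,
                 const std::vector<float>& ref,
                 double rtol = 1e-1,
                 double atol = 1e-1)
{
    if(out.size() != ref.size())
        return false;
    for(std::size_t i = 0; i < out.size(); ++i)
    {
        const double err = std::fabs(out[i] - ref[i]);
        if(err > atol + rtol * std::fabs(ref[i]))
        {
            std::printf("mismatch at %zu: %f vs %f\n", i, out[i], ref[i]);
            return false;
        }
    }
    return true;
}

int main()
{
    // fp8-scale rounding error of ~0.05 on a value of ~1.0 passes with rtol = atol = 0.1.
    std::vector<float> ref{1.00f, -2.00f, 0.50f};
    std::vector<float> out{1.05f, -2.10f, 0.48f};
    std::printf("pass = %d\n", check_close(out, ref));
}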
@@ -276,27 +295,6 @@ bool profile_gemm_multiply_multiply_impl(int do_verification,
                  << " TFlops, " << gb_per_sec << " GB/s, " << op_name << ", KBatch "
                  << kbatch_curr << std::endl;

-#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
-        // set softer tolerances for fp8
-        if constexpr((is_same_v<ADataType, f8_t> || is_same_v<BDataType, f8_t> ||
-                      is_same_v<EDataType, f8_t>) ||
-                     (is_same_v<ADataType, int8_t> || is_same_v<BDataType, int8_t> ||
-                      is_same_v<EDataType, int8_t>))
-        {
-            std::string msg = "Error: Incorrect results!";
-            double rtol     = 1e-1;
-            double atol     = 1e-1;
-            pass            = pass & ck::utils::check_err(
-                                  e_m_n_device_result, e_m_n_host_result, msg, rtol, atol);
-        }
-        else
-        {
-#endif
-            pass = pass & ck::utils::check_err(e_m_n_device_result, e_m_n_host_result);
-#if defined CK_ENABLE_FP8 || defined CK_ENABLE_INT8
-        }
-#endif
-
        if(tflops > best_tflops && ave_time > 1e-10)
        {
            best_op_name = op_name;