Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
14b422d7
Commit
14b422d7
authored
Feb 28, 2024
by
Jing Zhang
Browse files
debugging
parent
f221c68e
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
55 additions
and
50 deletions
+55
-50
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+41
-43
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+3
-3
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+4
-0
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+4
-2
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+3
-2
No files found.
example/01_gemm/gemm_wmma_fp16.cpp
View file @
14b422d7
...
@@ -21,49 +21,47 @@ using CElementOp = PassThrough;
...
@@ -21,49 +21,47 @@ using CElementOp = PassThrough;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
// clang-format off
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmWmma_CShuffle
<
using
DeviceGemmInstance
=
ck
::
tensor_operation
::
device
::
DeviceGemmWmma_CShuffle
ALayout
,
<
ALayout
,
BLayout
,
BLayout
,
CLayout
,
CLayout
,
ADataType
,
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
AccDataType
,
AccDataType
,
CShuffleDataType
,
CShuffleDataType
,
AElementOp
,
AElementOp
,
BElementOp
,
BElementOp
,
CElementOp
,
CElementOp
,
GemmDefault
,
GemmDefault
,
1
,
// Prefetch stage
1
,
// Prefetch stage
32
,
// BlockSize
128
,
// BlockSize
16
,
// MPerBlock
64
,
// MPerBlock
16
,
// NPerBlock
128
,
// NPerBlock
64
,
// KPerBlock
64
,
// KPerBlock
8
,
// K1
8
,
// K1
16
,
// MPerWmma
16
,
// MPerWmma
16
,
// NPerWmma
16
,
// NPerWmma
1
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
2
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
1
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
4
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
S
<
4
,
8
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
8
,
8
,
8
,
true
,
true
,
S
<
4
,
8
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
8
,
8
,
8
,
true
,
true
,
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
S
<
1
,
16
,
1
,
2
>
,
S
<
1
,
32
,
1
,
4
>
,
8
>
;
8
>
;
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
...
...
example/01_gemm/run_gemm_example.inc
View file @
14b422d7
...
@@ -73,12 +73,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
...
@@ -73,12 +73,12 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
ck
::
utils
::
FillUniformDistribution
<
BDataType
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
ck
::
utils
::
FillUniformDistribution
<
BDataType
>
{
-
1.
f
,
1.
f
}(
b_k_n
);
break
;
break
;
case
3
:
case
3
:
ck
::
utils
::
Fill
UniformDistributionIntegerValue
<
ADataType
>
{
1.
f
,
1.
f
}(
a_m_k
);
ck
::
utils
::
Fill
Constant
<
ADataType
>
{
static_cast
<
ADataType
>
(
1.
f
)
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
break
;
break
;
case
4
:
case
4
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
1
.
f
,
1
.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5
.
f
,
5
.
f
}(
a_m_k
);
ck
::
utils
::
Fill
UniformDistributionIntegerValue
<
BDataType
>
{
1.
f
,
1.
f
}(
b_k_n
);
ck
::
utils
::
Fill
Constant
<
BDataType
>
{
static_cast
<
BDataType
>
(
1.
f
)
}(
b_k_n
);
break
;
break
;
case
5
:
case
5
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
2.
f
,
2.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
2.
f
,
2.
f
}(
a_m_k
);
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
14b422d7
...
@@ -453,6 +453,7 @@ struct BlockwiseGemmWMMA
...
@@ -453,6 +453,7 @@ struct BlockwiseGemmWMMA
A_K1
>
;
A_K1
>
;
};
};
#if 0
template <>
template <>
struct AThreadCopySelector<false>
struct AThreadCopySelector<false>
{
{
...
@@ -467,6 +468,7 @@ struct BlockwiseGemmWMMA
...
@@ -467,6 +468,7 @@ struct BlockwiseGemmWMMA
5,
5,
A_K1>;
A_K1>;
};
};
#endif
template
<
bool
EnableLds
>
template
<
bool
EnableLds
>
struct
BThreadCopySelector
;
struct
BThreadCopySelector
;
...
@@ -486,6 +488,7 @@ struct BlockwiseGemmWMMA
...
@@ -486,6 +488,7 @@ struct BlockwiseGemmWMMA
B_K1
>
;
B_K1
>
;
};
};
#if 0
template <>
template <>
struct BThreadCopySelector<false>
struct BThreadCopySelector<false>
{
{
...
@@ -500,6 +503,7 @@ struct BlockwiseGemmWMMA
...
@@ -500,6 +503,7 @@ struct BlockwiseGemmWMMA
5,
5,
B_K1>;
B_K1>;
};
};
#endif
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
14b422d7
...
@@ -97,8 +97,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -97,8 +97,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
BEnableLds_manu
=
false
;
static
constexpr
auto
BEnableLds_manu
=
false
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
AEnableLds
=
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
true
;
// AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
static
constexpr
auto
BEnableLds
=
true
;
// BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
14b422d7
...
@@ -141,8 +141,8 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
...
@@ -141,8 +141,8 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
// Wave mode dependent propety
// Wave mode dependent propety
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
static
constexpr
index_t
num_src_a_vgprs_per_wave
=
k_per_wmma
/
2
*
src_a_data_size
/
4
;
//
static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
static
constexpr
index_t
num_src_b_vgprs_per_wave
=
k_per_wmma
/
2
*
src_b_data_size
/
4
;
//
static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
// * num_acc_vgprs_per_wave alone M direction
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
// * num_subgroups alone M direction
static
constexpr
index_t
num_acc_vgprs_per_wave
=
static
constexpr
index_t
num_acc_vgprs_per_wave
=
...
@@ -158,6 +158,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
...
@@ -158,6 +158,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
}
}
else
if
constexpr
(
wave_size
==
64
)
else
if
constexpr
(
wave_size
==
64
)
{
{
static_assert
(
1
,
""
);
}
}
}
}
};
};
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment