gaoqiong / composable_kernel / Commits / d4adc71a

Commit d4adc71a, authored Feb 24, 2023 by aska-0096
Mat-A LDS Bypass sanity pass
parent c811a0e9

Showing 12 changed files with 887 additions and 342 deletions (+887 -342)
example/01_gemm/gemm_wmma_fp16.cpp                                            +40  -6
example/01_gemm/run_gemm_example.inc                                          +12  -0
include/ck/host_utility/kernel_launch.hpp                                     +1   -1
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp                 +148 -51
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp              +100 -92
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp      +3   -1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp            +122 -5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp                   +396 -160
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp   +22  -17
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp                            +11  -4
include/ck/utility/amd_wmma.hpp                                               +8   -5
include/ck/utility/data_type.hpp                                              +24  -0
example/01_gemm/gemm_wmma_fp16.cpp  (view file @ d4adc71a)

@@ -19,15 +19,49 @@ using AElementOp = PassThrough;
 using BElementOp = PassThrough;
 using CElementOp = PassThrough;

-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
+static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::MNKPadding;

 // clang-format off
 using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
-    < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
-      AElementOp, BElementOp, CElementOp, GemmDefault,
-      256, 128, 256, 8, 8, 16, 16, 4, 4,
-      S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
-      S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
-      1, 1, S<1, 32, 1, 8>, 8, 1>;
+    < ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AccDataType, CShuffleDataType,
+      AElementOp, BElementOp, CElementOp, GemmDefault,
+      256,          // BlockSize
+      128,          // MPerBlock
+      128,          // NPerBlock
+      64,           // KPerBlock
+      8,            // K1
+      16,           // MPerWmma
+      16,           // NPerWmma
+      1,            // MRepeat
+      8,            // NRepeat
+      S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+      S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true,
+      1,            // CShuffle (MRepeat) per store
+      4,            // CShuffle (NRepeat) per store
+      S<1, 64, 1, 4>, 8>;
 // clang-format on

 using ReferenceGemmInstance = ck::tensor_operation::host::...
...
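The specialization switch from GemmSpecialization::Default to MNKPadding means the device op now rounds M, N, and K up to multiples of the block tile before building its grid descriptors (via the MatrixPadder introduced in device_gemm_wmma.hpp below). A minimal standalone sketch of that rounding, with hypothetical problem sizes, not CK's actual padder:

#include <cstdint>
#include <iostream>

// Round a raw extent up to the next multiple of the block tile.
std::int64_t pad_to_tile(std::int64_t raw, std::int64_t tile)
{
    return (raw + tile - 1) / tile * tile;
}

int main()
{
    // Hypothetical problem sizes, tile sizes from the example instance
    // (MPerBlock = 128, NPerBlock = 128, KPerBlock = 64).
    const std::int64_t M = 1000, N = 500, K = 100;
    std::cout << "M_padded = " << pad_to_tile(M, 128) << "\n"   // 1024
              << "N_padded = " << pad_to_tile(N, 128) << "\n"   // 512
              << "K_padded = " << pad_to_tile(K, 64)  << "\n";  // 128
    return 0;
}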
example/01_gemm/run_gemm_example.inc  (view file @ d4adc71a)

@@ -35,6 +35,18 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
         break;
+    case 2:
+        ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
+        ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
+        break;
+    case 3:
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{1.f, 1.f}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
+        break;
+    case 4:
+        ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
+        ck::utils::FillUniformDistributionIntegerValue<BDataType>{1.f, 1.f}(b_k_n);
+        break;
     default:
         ck::utils::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k);
         ck::utils::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n);
...
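The new init cases feed the kernel constant-one and small-integer inputs, which makes numerical mismatches introduced by the LDS-bypass path easy to localize (e.g. case 3 isolates errors on the B side). A minimal sketch of what a fill functor of this shape might look like; the real fillers live in ck::utils, and the names and seed here are illustrative only:

#include <random>
#include <vector>

// Illustrative functor: fill a container with uniformly distributed values in [lo, hi].
template <typename T>
struct FillUniform
{
    float lo_;
    float hi_;

    template <typename Container>
    void operator()(Container& data) const
    {
        std::mt19937 gen{11939};
        std::uniform_real_distribution<float> dist{lo_, hi_};
        for(auto& v : data)
            v = static_cast<T>(dist(gen));
    }
};

// Usage mirroring the example code: FillUniform<float>{-1.f, 1.f}(a_m_k);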
include/ck/host_utility/kernel_launch.hpp  (view file @ d4adc71a)

@@ -35,7 +35,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
     // warm up
     kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);

-    const int nrepeat = 10;
+    const int nrepeat = 100;
 #if DEBUG_LOG
     printf("Start running %d times...\n", nrepeat);
 #endif
...
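Raising nrepeat from 10 to 100 simply averages the timed launches over more iterations to reduce measurement noise. A rough standalone sketch of that kind of timing loop using HIP events; this is an assumed pattern for illustration, not CK's launch_and_time_kernel:

#include <hip/hip_runtime.h>

// Time an already-configured kernel launch, averaged over nrepeat iterations.
template <typename Kernel, typename... Args>
float time_kernel_avg_ms(Kernel kernel, dim3 grid, dim3 block, hipStream_t stream, Args... args)
{
    const int nrepeat = 100;

    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, stream);
    for(int i = 0; i < nrepeat; ++i)
        kernel<<<grid, block, 0, stream>>>(args...);
    hipEventRecord(stop, stream);
    hipEventSynchronize(stop);

    float total_ms = 0.f;
    hipEventElapsedTime(&total_ms, start, stop);

    hipEventDestroy(start);
    hipEventDestroy(stop);
    return total_ms / nrepeat; // average time per launch
}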
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp  (view file @ d4adc71a)

@@ -62,12 +62,33 @@ struct BlockwiseGemmWMMA
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I4);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I4);

+    static constexpr auto A_temp0 = Number<ABlockDesc{}.GetLength(I0)>{};
+    static constexpr auto A_temp1 = Number<ABlockDesc{}.GetLength(I1)>{};
+    static constexpr auto A_temp2 = Number<ABlockDesc{}.GetLength(I2)>{};
+    static constexpr auto A_temp3 = Number<ABlockDesc{}.GetLength(I3)>{};
+    static constexpr auto A_temp4 = Number<ABlockDesc{}.GetLength(I4)>{};
+    // FIX it, workaround
+    using ABlockDesc_temp = decltype(make_naive_tensor_descriptor(
+        make_tuple(A_temp0, A_temp1, A_temp2, A_temp3, A_temp4),
+        make_tuple(A_temp1 * A_temp2 * A_temp3 * A_temp4,
+                   A_temp2 * A_temp3 * A_temp4,
+                   A_temp3 * A_temp4,
+                   A_temp4,
+                   I1)));
+
     static constexpr auto wmma_gemm =
         WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC>{};

     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);

+    static constexpr bool AEnableLds = NWaves == 1 ? false : true;
+    static constexpr bool BEnableLds = MWaves == 1 ? false : true;
+    // Read from Lds, duplicate Twice, Read from VGPR, no duplication.
+    static constexpr index_t A_Data_Duplicated_Rate = AEnableLds ? 2 : 1;
+    static constexpr index_t B_Data_Duplicated_Rate = BEnableLds ? 2 : 1;
+
     StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                               FloatAcc,
                               MRepeat * NRepeat,
...

@@ -92,24 +113,36 @@ struct BlockwiseGemmWMMA
     // Default, Block buffer in LDS, thread level offset enabled
     __device__ static auto CalculateAThreadOriginDataIndex()
     {
-        const auto wave_idx   = GetWaveIdx();
-        const auto waveId_m   = wave_idx[I0];
-        const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
-
-        //  |KRepeat |MRepeat|MWave    |MLane     |KPack
-        return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
+        if constexpr(AEnableLds)
+        {
+            const auto wave_idx   = GetWaveIdx();
+            const auto waveId_m   = wave_idx[I0];
+            const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
+
+            //  |KRepeat |MRepeat|MWave    |MLane     |KPack
+            return make_tuple(0, 0, waveId_m, WMMA_a_idx, 0);
+        }
+        else
+        {
+            return make_tuple(0, 0, 0, 0, 0);
+        }
     }

     __device__ static auto CalculateBThreadOriginDataIndex()
     {
-        const auto wave_idx   = GetWaveIdx();
-        const auto waveId_n   = wave_idx[I1];
-        const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
-
-        //  |KRepeat |NRepeat|Nwave    |NLane     |KPack
-        return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
+        if constexpr(BEnableLds)
+        {
+            const auto wave_idx   = GetWaveIdx();
+            const auto waveId_n   = wave_idx[I1];
+            const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
+
+            //  |KRepeat |NRepeat|Nwave    |NLane     |KPack
+            return make_tuple(0, 0, waveId_n, WMMA_b_idx, 0);
+        }
+        else
+        {
+            return make_tuple(0, 0, 0, 0, 0);
+        }
     }

     template <index_t m0, index_t n0>
...

@@ -269,7 +302,7 @@ struct BlockwiseGemmWMMA
     // Describe how data allocated in thread copy src buffer
     // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
-    static constexpr ABlockDesc a_block_desc_k0_m0_m1_m2_k1;
+    static constexpr ABlockDesc_temp a_block_desc_k0_m0_m1_m2_k1;
     static constexpr BBlockDesc b_block_desc_k0_n0_n1_n2_k1;

     template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
...

@@ -285,21 +318,28 @@ struct BlockwiseGemmWMMA
         static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 // read A
                 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
-                                   make_tuple(Number<k * WmmaK / A_K1>{}, m0, I0, I0, I0),
+                                   make_tuple(Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{},
+                                              m0, I0, I0, I0),
                                    a_block_buf,
                                    a_thread_desc_,
                                    make_tuple(I0, m0, I0, I0, I0),
                                    a_thread_buf);

                 static_for<0, NRepeat, 1>{}([&](auto n0) {
                     // read B
                     b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
-                                       make_tuple(Number<k * WmmaK / B_K1>{}, n0, I0, I0, I0),
+                                       make_tuple(Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{},
+                                                  n0, I0, I0, I0),
                                        b_block_buf,
                                        b_thread_desc_,
                                        make_tuple(I0, n0, I0, I0, I0),
                                        b_thread_buf);

                     vector_type<FloatA, WmmaK> a_thread_vec;
                     vector_type<FloatB, WmmaK> b_thread_vec;
...

@@ -324,6 +364,7 @@ struct BlockwiseGemmWMMA
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
         });
     }
...

@@ -340,28 +381,78 @@ struct BlockwiseGemmWMMA
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));

-    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
-                                                         FloatA,
-                                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
-                                                         decltype(a_thread_desc_),
-                                                         Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
-                                                         Sequence<0, 1, 2, 3, 4>,
-                                                         4,
-                                                         A_K1,
-                                                         A_K1>;
-
-    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
-                                                         FloatB,
-                                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
-                                                         decltype(b_thread_desc_),
-                                                         Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
-                                                         Sequence<0, 1, 2, 3, 4>,
-                                                         4,
-                                                         B_K1,
-                                                         B_K1>;
-
-    AThreadCopy a_thread_copy_;
-    BThreadCopy b_thread_copy_;
+    template <bool EnableLds>
+    struct AThreadCopySelector;
+
+    template <>
+    struct AThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v4<FloatA,
+                                                      FloatA,
+                                                      decltype(a_block_desc_k0_m0_m1_m2_k1),
+                                                      decltype(a_thread_desc_),
+                                                      Sequence<WmmaK / A_K1, 1, 1, 1, A_K1>,
+                                                      Sequence<0, 1, 2, 3, 4>,
+                                                      4,
+                                                      A_K1,
+                                                      A_K1>;
+    };
+
+    template <>
+    struct AThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
+            FloatA,
+            FloatA,
+            decltype(a_block_desc_k0_m0_m1_m2_k1),
+            decltype(a_thread_desc_),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<1, 1, 1, 1, A_K1>,
+            Sequence<0, 1, 2, 3, 4>,
+            4,
+            A_K1,
+            0x76543210,
+            0xfedcba98,
+            true>;
+    };
+
+    template <bool EnableLds>
+    struct BThreadCopySelector;
+
+    template <>
+    struct BThreadCopySelector<true>
+    {
+        using type = ThreadwiseTensorSliceTransfer_v4<FloatB,
+                                                      FloatB,
+                                                      decltype(b_block_desc_k0_n0_n1_n2_k1),
+                                                      decltype(b_thread_desc_),
+                                                      Sequence<WmmaK / B_K1, 1, 1, 1, B_K1>,
+                                                      Sequence<0, 1, 2, 3, 4>,
+                                                      4,
+                                                      B_K1,
+                                                      B_K1>;
+    };
+
+    template <>
+    struct BThreadCopySelector<false>
+    {
+        using type = ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow<
+            FloatB,
+            FloatB,
+            decltype(b_block_desc_k0_n0_n1_n2_k1),
+            decltype(b_thread_desc_),
+            tensor_operation::element_wise::PassThrough,
+            Sequence<1, 1, 1, 1, B_K1>,
+            Sequence<0, 1, 2, 3, 4>,
+            4,
+            B_K1,
+            0x76543210,
+            0xfedcba98,
+            false>;
+    };
+
+    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
+    typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
 };

 // block wise level pipe designed for inline asm
...

@@ -376,7 +467,7 @@ template <index_t BlockSize,
           index_t MRepeat,
           index_t NRepeat,
           index_t KPack,
           bool TransposeC      = false,
           bool AssemblyBackend = true>
 /* A: K0PerBlock x MPerBlock x K1
  * B: K0PerBlock x NPerBlock x K1
...

@@ -407,8 +498,14 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle_FIFO
     static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
     static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

     static constexpr auto wmma_gemm =
         WmmaGemm<FloatA, FloatB, FloatAcc, MPerWMMA, NPerWMMA, KPack, TransposeC, AssemblyBackend>{};

     static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
     static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);
...
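The core of the Mat-A LDS bypass is the pair of compile-time switches above: a block input skips LDS exactly when its wave count along the other dimension is 1 (every wave then consumes the same tile directly from registers), and only the LDS path needs its data duplicated. A tiny compile-time sketch of that selection logic, detached from the CK types:

#include <cstdio>

// Compile-time mirror of the AEnableLds / BEnableLds selection in BlockwiseGemmWMMA.
template <int MPerBlock, int NPerBlock, int MRepeat, int NRepeat, int MPerWmma, int NPerWmma>
struct LdsPolicy
{
    static constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma);
    static constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma);

    // A goes through LDS only when more than one wave column shares it.
    static constexpr bool AEnableLds = NWaves != 1;
    static constexpr bool BEnableLds = MWaves != 1;

    // Read from LDS: data duplicated twice; read from VGPR: no duplication.
    static constexpr int ADupRate = AEnableLds ? 2 : 1;
    static constexpr int BDupRate = BEnableLds ? 2 : 1;
};

int main()
{
    // The example instance: 128x128 tile, MRepeat = 1, NRepeat = 8, 16x16 WMMA.
    using P = LdsPolicy<128, 128, 1, 8, 16, 16>;
    static_assert(P::MWaves == 8 && P::NWaves == 1, "wave layout");
    static_assert(!P::AEnableLds && P::BEnableLds, "A bypasses LDS, B does not");
    std::printf("ADupRate=%d BDupRate=%d\n", P::ADupRate, P::BDupRate);
    return 0;
}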
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp  (view file @ d4adc71a)

@@ -15,6 +15,7 @@
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
+#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"

 namespace ck {
 namespace tensor_operation {
...

@@ -35,10 +36,10 @@ template <typename ALayout,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
-          ck::index_t K0PerBlock,
+          ck::index_t KPerBlock,
           ck::index_t K1,
-          ck::index_t MPerWMMA,
-          ck::index_t NPerWMMA,
+          ck::index_t MPerWmma,
+          ck::index_t NPerWmma,
           ck::index_t MRepeat,
           ck::index_t NRepeat,
           typename ABlockTransferThreadClusterLengths_K0_M_K1,
...

@@ -75,19 +76,31 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
+    static constexpr auto I3 = Number<3>{};
+    static constexpr auto I4 = Number<4>{};
+    static constexpr auto I5 = Number<5>{};

     // K1 = Max Vector Access Pixels
     static constexpr auto K1Number = Number<K1>{};

-    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
-    {
-        assert(K % K1 == 0);
-        const index_t K0 = K / K1;
+    static constexpr auto MWaves = MPerBlock / (MRepeat * MPerWmma);
+    static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
+    static constexpr auto WmmaK  = 16;
+
+    static constexpr auto AEnableLds = NWaves == 1 ? false : true;
+    static constexpr auto BEnableLds = MWaves == 1 ? false : true;
+
+    static constexpr auto matrix_padder =
+        MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
+
+    // Describe how data read from Global memory
+    static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
+    {
         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
+                const auto a_grid_desc_mraw_kraw =
+                    make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(StrideA, I1));
+                return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
             }
 #ifdef ENABLE_COLMAJOR
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
...

@@ -97,104 +110,88 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
 #endif
         }();

-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
+        const auto M = a_grid_desc_m_k.GetLength(I0);
+        const auto K = a_grid_desc_m_k.GetLength(I1);
+        assert(K % K1 == 0);
+
+        if constexpr(AEnableLds)
         {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
+            const index_t K0 = K / K1;
             return transform_tensor_descriptor(
                 a_grid_desc_m_k,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(M, PadM)),
+                           make_pass_through_transform(a_grid_desc_m_k.GetLength(I0))),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }
         else
         {
+            constexpr auto A_KRow = WmmaK / K1;
+            const auto A_KWmma    = K / WmmaK;
+            const auto M0         = M / MPerBlock;
             return transform_tensor_descriptor(
                 a_grid_desc_m_k,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(M)),
+                make_tuple(make_unmerge_transform(make_tuple(A_KWmma, Number<A_KRow>{}, K1Number)),
+                           make_unmerge_transform(
+                               make_tuple(M0 * MRepeat, Number<MWaves>{}, Number<MPerWmma>{}))),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+                make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
         }
     }

-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
+    static auto MakeBGridDescriptor_K0_N_K1(index_t KRaw, index_t NRaw, index_t StrideB)
     {
-        assert(K % K1 == 0);
-        const index_t K0 = K / K1;
-        const auto b_grid_desc_k_n = [&]() {
+        const auto b_grid_desc_nraw_kraw = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
+                return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
             }
         }();

-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-        }
+        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
+
+        const auto N = b_grid_desc_n_k.GetLength(I0);
+        const auto K = b_grid_desc_n_k.GetLength(I1);
+        assert(K % K1 == 0);
+        const index_t K0 = K / K1;
+
+        return transform_tensor_descriptor(
+            b_grid_desc_n_k,
+            make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
+                       make_pass_through_transform(N)),
+            make_tuple(Sequence<1>{}, Sequence<0>{}),
+            make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
     }

-    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
+    static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
     {
-        const auto c_grid_desc_m_n = [&]() {
+        const auto c_grid_desc_mraw_nraw = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(StrideC, I1));
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
+                return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw), make_tuple(I1, StrideC));
             }
         }();

-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-            return transform_tensor_descriptor(
-                c_grid_desc_m_n,
-                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                c_grid_desc_m_n,
-                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
+        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
     }

     // Gridwise descriptor, mapping to whole given provblem.
-    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
+    using AGridDesc         = decltype(MakeAGridDescriptor(1, 1, 1));
     using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
     using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
...

@@ -207,7 +204,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         CShuffleDataType,
         CDataType,
         InMemoryDataOperationEnum::Set,
-        AGridDesc_K0_M_K1,
+        AGridDesc,
         BGridDesc_K0_N_K1,
         CGridDesc_M_N,
         AElementwiseOperation,
...

@@ -215,9 +212,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         CElementwiseOperation,
         MPerBlock,
         NPerBlock,
-        K0PerBlock,
-        MPerWMMA,
-        NPerWMMA,
+        KPerBlock,
+        MPerWmma,
+        NPerWmma,
         K1,
         MRepeat,
         NRepeat,
...

@@ -228,6 +225,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         ABlockTransferSrcScalarPerVector,
         ABlockTransferDstScalarPerVector_K1,
         false, // AThreadTransferSrcResetCoordinateAfterRun,
+        AEnableLds,
         ABlockLdsAddExtraM,
         BBlockTransferThreadClusterLengths_K0_N_K1,
         BBlockTransferThreadClusterArrangeOrder,
...

@@ -236,6 +234,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         BBlockTransferSrcScalarPerVector,
         BBlockTransferDstScalarPerVector_K1,
         false, // BThreadTransferSrcResetCoordinateAfterRun,
+        BEnableLds,
         BBlockLdsAddExtraN,
         CShuffleMRepeatPerShuffle,
         CShuffleNRepeatPerShuffle,
...

@@ -265,7 +264,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             : p_a_grid_{p_a_grid},
               p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
-              a_grid_desc_k0_m_k1_{},
+              a_grid_desc_{},
              b_grid_desc_k0_n_k1_{},
              c_grid_desc_m_n_{},
              c_grid_desc_mblock_mperblock_nblock_nperblock{},
...

@@ -276,8 +275,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
              b_element_op_{b_element_op},
              c_element_op_{c_element_op}
         {
-            a_grid_desc_k0_m_k1_ =
-                DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
+            a_grid_desc_ = DeviceGemmWmma_CShuffle::MakeAGridDescriptor(M, K, StrideA);
             b_grid_desc_k0_n_k1_ =
                 DeviceGemmWmma_CShuffle::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
             c_grid_desc_m_n_ = DeviceGemmWmma_CShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
...

@@ -285,10 +283,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             block_2_ctile_map_ =
                 GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);

-            if(GridwiseGemm::CheckValidity(a_grid_desc_k0_m_k1_,
-                                           b_grid_desc_k0_n_k1_,
-                                           c_grid_desc_m_n_,
-                                           block_2_ctile_map_))
+            if(GridwiseGemm::CheckValidity(
+                   a_grid_desc_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, block_2_ctile_map_))
             {
                 c_grid_desc_mblock_mperblock_nblock_nperblock =
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
...

@@ -300,7 +296,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
-        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
+        AGridDesc a_grid_desc_;
         BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
...

@@ -322,9 +318,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         {
 #if 0
             {
-                std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
-                          << ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
-                          << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
+                std::cout << "arg.a_grid_desc_{" << arg.a_grid_desc_.GetLength(I0)
+                          << ", " << arg.a_grid_desc_.GetLength(I1) << ", "
+                          << arg.a_grid_desc_.GetLength(I2) << "}" << std::endl;

                 std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
                           << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
...

@@ -336,7 +332,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             }
 #endif

-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
                                             arg.b_grid_desc_k0_n_k1_,
                                             arg.c_grid_desc_m_n_,
                                             arg.block_2_ctile_map_))
...

@@ -348,8 +344,18 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             const index_t grid_size =
                 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto K =
-                arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
+            const auto GetK = [&]() {
+                if constexpr(AEnableLds)
+                {
+                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
+                }
+                else
+                {
+                    return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
+                           arg.a_grid_desc_.GetLength(I5);
+                }
+            };
+            const auto K = GetK();

             float ave_time = 0;
...

@@ -360,7 +366,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                     ADataType,
                     BDataType,
                     CDataType,
-                    remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
                     remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
                     remove_reference_t<
                         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
...

@@ -378,7 +384,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.a_grid_desc_k0_m_k1_,
+                    arg.a_grid_desc_,
                     arg.b_grid_desc_k0_n_k1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                     arg.a_element_op_,
...

@@ -393,7 +399,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                     ADataType,
                     BDataType,
                     CDataType,
-                    remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceGemmWmma_CShuffle::AGridDesc>,
                     remove_reference_t<DeviceGemmWmma_CShuffle::BGridDesc_K0_N_K1>,
                     remove_reference_t<
                         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
...

@@ -411,7 +417,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.a_grid_desc_k0_m_k1_,
+                    arg.a_grid_desc_,
                     arg.b_grid_desc_k0_n_k1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock,
                     arg.a_element_op_,
...

@@ -443,15 +449,17 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
+                printf("DeviceOp err: AccDataType");
                 return false;
             }
         }
         else
         {
+            printf("DeviceOp err: Arch");
             return false;
         }

-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
+        return GridwiseGemm::CheckValidity(arg.a_grid_desc_,
                                            arg.b_grid_desc_k0_n_k1_,
                                            arg.c_grid_desc_m_n_,
                                            arg.block_2_ctile_map_);
...

@@ -547,10 +555,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
-            << K0PerBlock << ", "
+            << KPerBlock << ", "
             << K1 << ", "
-            << MPerWMMA << ", "
-            << NPerWMMA << ", "
+            << MPerWmma << ", "
+            << NPerWmma << ", "
             << MRepeat << ", "
             << NRepeat << ">"
...
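When A bypasses LDS, MakeAGridDescriptor no longer produces a plain K0 x M x K1 view: it unmerges K into (A_KWmma, A_KRow, K1) and M into (M0 * MRepeat, MWaves, MPerWmma), which matches the launcher's GetK() multiplying descriptor lengths I0, I3 and I5 in that branch. A small worked example of the factor arithmetic, using the example instance's sizes (illustrative numbers only):

#include <cassert>
#include <cstdio>

int main()
{
    // Example-instance parameters, after MNK padding (illustrative numbers).
    const int K = 1024, M = 1024;
    const int K1 = 8, WmmaK = 16, MPerBlock = 128, MRepeat = 1, MWaves = 8, MPerWmma = 16;

    // Bypass-A view: K is unmerged into (A_KWmma, A_KRow, K1).
    const int A_KRow  = WmmaK / K1; // 2
    const int A_KWmma = K / WmmaK;  // 64
    assert(A_KWmma * A_KRow * K1 == K);

    // ... and M into (M0 * MRepeat, MWaves, MPerWmma).
    const int M0 = M / MPerBlock;   // 8
    assert(M0 * MRepeat * MWaves * MPerWmma == M);

    // GetK() in the bypass branch multiplies descriptor lengths I0, I3, I5,
    // i.e. A_KWmma * A_KRow * K1, recovering the padded K.
    std::printf("K recovered from (I0, I3, I5) = %d\n", A_KWmma * A_KRow * K1);
    return 0;
}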
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp  (view file @ d4adc71a)

@@ -15,6 +15,8 @@ enum struct PipelineVersion
 };

 template <PipelineVersion PipelineVer,
+          bool AEnableLds         = true,
+          bool BEnableLds         = true,
           index_t NumPrefetch     = 1,
           LoopScheduler LoopSched = LoopScheduler::Default>
 constexpr auto GridwiseGemmPipeline_Selector()
...

@@ -23,7 +25,7 @@ constexpr auto GridwiseGemmPipeline_Selector()
 {
     if constexpr(LoopSched == LoopScheduler::Default)
     {
-        return GridwiseGemmPipeline_v1<NumPrefetch>{};
+        return GridwiseGemmPipeline_v1<NumPrefetch, AEnableLds, BEnableLds>{};
     }
     else if constexpr(LoopSched == LoopScheduler::Interwave)
    {
...
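The selector simply forwards the two new flags into the pipeline's template arguments, so the (AEnableLds, BEnableLds) combination picks a specialization at compile time. A stripped-down sketch of that dispatch pattern in generic C++, not the CK types; the struct names are placeholders:

#include <type_traits>

struct PipelineBothLds { /* A and B staged through LDS    */ };
struct PipelineABypass { /* A kept in registers, B in LDS */ };

// Compile-time dispatch on the two flags, mirroring GridwiseGemmPipeline_Selector.
template <bool AEnableLds, bool BEnableLds>
constexpr auto make_pipeline()
{
    static_assert(BEnableLds, "this sketch covers only the two combinations the commit exercises");
    if constexpr(AEnableLds)
        return PipelineBothLds{};
    else
        return PipelineABypass{};
}

static_assert(std::is_same_v<decltype(make_pipeline<true, true>()), PipelineBothLds>);
static_assert(std::is_same_v<decltype(make_pipeline<false, true>()), PipelineABypass>);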
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp  (view file @ d4adc71a)

@@ -8,12 +8,12 @@
 namespace ck {

-template <index_t NumPrefetch>
+template <index_t NumPrefetch, bool AEnableLds, bool BEnableLds>
 struct GridwiseGemmPipeline_v1;

 // 1-stage prefetch
 template <>
-struct GridwiseGemmPipeline_v1<1>
+struct GridwiseGemmPipeline_v1<1, true, true>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...

@@ -107,7 +107,7 @@ struct GridwiseGemmPipeline_v1<1>
 // 2-stage prefetch
 template <>
-struct GridwiseGemmPipeline_v1<2>
+struct GridwiseGemmPipeline_v1<2, true, true>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...

@@ -253,6 +253,123 @@ struct GridwiseGemmPipeline_v1<2>
     }
 };

+template <>
+struct GridwiseGemmPipeline_v1<1, false, true>
+{
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    __host__ __device__ static constexpr bool IsSupported(index_t /* num_loop */) { return true; }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 1;
+    }
+
+    template <bool HasMainLoop,
+              typename AGridDesc,
+              typename ABlockDesc,
+              typename ABlockTransfer,
+              typename AGridBuffer,
+              typename ABlockBuffer,
+              typename ABlockTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    __device__ static void Run(const AGridDesc& a_grid_desc,
+                               const ABlockDesc& a_block_desc,
+                               ABlockTransfer& a_blockwise_copy,
+                               const AGridBuffer& a_grid_buf,
+                               ABlockBuffer& a_block_buf,
+                               const ABlockTransferStep& a_block_copy_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf,
+                               index_t num_loop)
+    {
+#if 0
+        constexpr auto a_block_origin_idx = generate_sequence_v2(
+            []() constexpr { return Number<0>{}; },
+            Number<a_block_desc.GetLengths().GetSize()>{});
+#endif
+        constexpr auto a_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0);
+        auto a_block_buf_switch           = a_block_buf;
+
+        // preload data into LDS
+        a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf);
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+        // main body
+        if constexpr(HasMainLoop)
+        {
+            index_t i = 0;
+            do
+            {
+                a_blockwise_copy.Run(
+                    a_grid_desc, a_grid_buf, a_block_desc, a_block_origin_idx, a_block_buf_switch);
+                block_sync_lds();
+                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+                blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+                block_sync_lds();
+
+                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+                a_block_buf = a_block_buf_switch;
+
+                ++i;
+            } while(i < (num_loop - 1));
+        }
+
+        // tail
+        {
+            block_sync_lds();
+            blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
+            block_sync_lds();
+        }
+    }
+};
+
+// placeholder
+template <>
+struct GridwiseGemmPipeline_v1<1, true, false>
+{
+};
+
+template <>
+struct GridwiseGemmPipeline_v1<1, false, false>
+{
+};
+
 template <index_t NumPrefetch>
 struct GridwiseGemmPipelineInterwave_v1;
...

@@ -348,7 +465,7 @@ struct GridwiseGemmPipelineInterwave_v1<1>
 // Note: 2 stage prefetch not optimized for inter-wave loop scheduler
 template <>
-struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2>
+struct GridwiseGemmPipelineInterwave_v1<2> : public GridwiseGemmPipeline_v1<2, true, true>
 {
 };
...

@@ -358,7 +475,7 @@ constexpr auto GridwiseGemmPipeline_v1_Selector()
 {
     if constexpr(LoopSched == LoopScheduler::Default)
     {
-        return GridwiseGemmPipeline_v1<NumPrefetch>{};
+        return GridwiseGemmPipeline_v1<NumPrefetch, true, true>{};
     }
     else if constexpr(LoopSched == LoopScheduler::Interwave)
     {
...
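The new <1, false, true> specialization prefetches the next A tile into a second register buffer (a_block_buf_switch) while the current one feeds the WMMA loop, so only B still needs the LDS write and barrier. A schematic, library-free sketch of that single-stage prefetch structure; every callable here is a placeholder, not a CK API:

// Schematic single-stage prefetch loop: A lives in registers (double buffered),
// B is staged through shared memory. All callables are placeholders.
template <typename Tile, typename LoadA, typename LoadB, typename StoreBToLds,
          typename Gemm, typename Sync>
void pipeline_a_bypass(LoadA load_a, LoadB load_b, StoreBToLds store_b,
                       Gemm gemm, Sync sync_lds, int num_loop)
{
    Tile a_cur = load_a(0); // preload the first A tile into registers
    Tile b_glb = load_b(0); // first B tile into registers ...
    store_b(b_glb);         // ... then publish it to LDS

    for(int i = 0; i < num_loop - 1; ++i)
    {
        Tile a_next = load_a(i + 1); // prefetch next A tile (registers only)
        sync_lds();
        b_glb = load_b(i + 1);       // prefetch next B tile from global memory

        gemm(a_cur);                 // consume current A (registers) and B (LDS)
        sync_lds();

        store_b(b_glb);              // publish next B tile to LDS
        a_cur = a_next;              // swap the A register buffers
    }

    sync_lds();
    gemm(a_cur);                     // tail iteration
}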
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp  (view file @ d4adc71a)

(This diff is collapsed in the page view; +396 additions, -160 deletions.)
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  (view file @ d4adc71a)

@@ -1324,15 +1324,14 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
     using Index = MultiIndex<nDim>;

-    __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(const Index& src_idx)
+    __device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow(
+        const ElementwiseOperation& element_op)
+        : element_op_{element_op}
     {
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                       "wrong! Desc need to known at compile-time");

         static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
                       "wrong! Not divisible");
-
-        ignore = src_idx;
     }

     template <typename SrcSliceOriginIdx,
...

@@ -1344,7 +1343,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
              const SrcBuffer& src_buf,
              const DstDesc&,
              const DstSliceOriginIdx&,
-             DstBuffer& dst_buf)
+             DstBuffer& dst_buf) const
     {
         static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                       "wrong! Desc need to known at compile-time");
...

@@ -1383,7 +1382,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
             // copy data from src_buf into dst_vector
             static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
-                // idx_md err. as dst access 2 strided elements while src visit 1 per loop
+                // src_desc error, non constexpr?
                 constexpr index_t src_offset = src_desc.CalculateOffset(
                     src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
...

@@ -1396,16 +1395,22 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
                 // apply element-wise operation
                 element_op_(v_this_row, src_buf[Number<src_offset>{}]);

-                // apply intra-row swizzle permute
+                // if (get_thread_local_1d_id() < 16)
+                //     printf("tid: %03d, RawData: %04x\n", get_thread_local_1d_id(),
+                //            *(reinterpret_cast<uint16_t*>(&v_this_row)) );
+                // apply intra-row swizzle permute
                 if constexpr(IntraRowSwizzlePerm)
                 {
-                    temp = __builtin_amdgcn_permlane16(
-                        temp, type_convert<int>(v_this_row), 0xeca86420, 0xfdb97531, 1, 0);
-                    v_this_row = type_convert<float>(temp);
+                    // origin: 0x76543210, 0xfedcba98
+                    temp = __builtin_amdgcn_permlane16(
+                        temp, type_convert<int>(v_this_row), 0xb3a29180, 0xf7e6d5c4, 1, 0);
+                    v_this_row = type_convert<SrcData>(temp);
+                    // if (get_thread_local_1d_id() < 16)
+                    //     printf("tid: %03d, SwiData: %04x\n", get_thread_local_1d_id(),
+                    //            *(reinterpret_cast<uint16_t*>(&v_this_row)) );
                 }

                 // apply inter-row permute.
...

@@ -1415,8 +1420,9 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
                     HighEightRowLaneIdx,
                     1,
                     0);
-                v_theother_row = type_convert<float>(temp);
+                v_theother_row = type_convert<SrcData>(temp);
+                // printf("tid: %03d, PermData: %04x\n", get_thread_local_1d_id(),
+                //        *(reinterpret_cast<uint16_t*>(&v_theother_row)) );

                 if(get_thread_local_1d_id() % 32 < 16)
                 {
                     // apply type convert
...

@@ -1434,8 +1440,7 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
             });
         });
     }
-
-    ElementwiseOperation element_op_{};
+    ElementwiseOperation element_op_;
 };

 } // namespace ck
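The intra-row swizzle relies on the two 32-bit permlane16 selectors, where each nibble names the source lane (within a row of 16) for one destination lane; 0x76543210 / 0xfedcba98 is the identity noted in the code comment, and the new 0xb3a29180 / 0xf7e6d5c4 pair interleaves the low and high half-rows. A host-side simulation of that encoding; this reflects my reading of the selector layout and should be treated as an assumption rather than a hardware reference:

#include <array>
#include <cstdint>
#include <cstdio>

// Simulate a 16-lane permlane-style shuffle: destination lane i takes its value
// from the source lane named by nibble i of the pair (sel_hi : sel_lo).
std::array<int, 16> permlane16_sim(const std::array<int, 16>& src,
                                   std::uint32_t sel_lo, std::uint32_t sel_hi)
{
    std::array<int, 16> dst{};
    for(int i = 0; i < 16; ++i)
    {
        const std::uint32_t sel = (i < 8) ? sel_lo : sel_hi;
        const int lane          = (sel >> (4 * (i % 8))) & 0xF;
        dst[i]                  = src[lane];
    }
    return dst;
}

int main()
{
    std::array<int, 16> lanes{};
    for(int i = 0; i < 16; ++i)
        lanes[i] = i;

    // Identity selectors from the code comment, and the new swizzle selectors.
    auto ident = permlane16_sim(lanes, 0x76543210u, 0xfedcba98u);
    auto swizz = permlane16_sim(lanes, 0xb3a29180u, 0xf7e6d5c4u);

    for(int i = 0; i < 16; ++i)
        std::printf("lane %2d: identity=%2d swizzled=%2d\n", i, ident[i], swizz[i]);
    return 0;
}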
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp  (view file @ d4adc71a)

@@ -103,7 +103,12 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
         m_per_wmma * n_per_wmma * acc_data_size / wave_size / 4;
     static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;

     template <index_t MPerWmma,
               index_t NPerWmma,
               bool AssemblyBackend,
               class FloatA,
               class FloatB,
               class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
     {
         if constexpr(wave_size == 32)
...

@@ -358,7 +363,7 @@ template <typename src_type_a,
           index_t MPerWmma,
           index_t NPerWmma,
           index_t KPack,
           bool TransposeC      = false,
           bool AssemblyBackend = false>
 struct WmmaGemm
 {
...

@@ -492,11 +497,13 @@ struct WmmaGemm
                       "(int8, int32) or (int4, int32)!");

         if constexpr(!TransposeC)
         {
             wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(p_a_wave, p_b_wave, p_c_thread);
         }
         else
         {
             wmma_instr.template run<MPerWmma, NPerWmma, AssemblyBackend>(p_b_wave, p_a_wave, p_c_thread);
         }
     }
...
include/ck/utility/amd_wmma.hpp  (view file @ d4adc71a)

@@ -21,13 +21,16 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16, AssemblyBackend>
     template <class FloatC>
     __device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
     {
-        if constexpr(AssemblyBackend){
+        if constexpr(AssemblyBackend)
+        {
             amd_assembly_wmma_f32_16x16x16_f16_w32(
                 reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
         }
         else
+        {
             reg_c.template AsType<float8_t>()(Number<0>{}) =
                 __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
                     reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
+        }
     }
 };
...
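Both backends above compute the same 16x16x16 fp16 multiply with fp32 accumulation; only the code path differs (inline assembly vs. the wmma builtin). A plain host reference of what one such fragment computes, handy for sanity-checking either path (ordinary floats here, not the packed half16_t/float8_t register types):

#include <array>

// Host reference for one 16x16x16 WMMA fragment: C += A * B with fp32 accumulation.
void wmma_fragment_reference(const std::array<float, 16 * 16>& a,  // 16x16, row-major
                             const std::array<float, 16 * 16>& b,  // 16x16, row-major
                             std::array<float, 16 * 16>& c)        // 16x16, row-major
{
    for(int m = 0; m < 16; ++m)
        for(int n = 0; n < 16; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < 16; ++k)
                acc += a[m * 16 + k] * b[k * 16 + n];
            c[m * 16 + n] += acc;
        }
}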
include/ck/utility/data_type.hpp  (view file @ d4adc71a)

@@ -988,6 +988,30 @@ inline __host__ __device__ constexpr float type_convert<float, int>(int x)
     return u.fp32;
 }

+template <>
+inline __host__ __device__ constexpr int type_convert<int, half_t>(half_t x)
+{
+    union
+    {
+        half_t fp16;
+        int int32;
+    } u = {x};
+
+    return u.int32;
+}
+
+template <>
+inline __host__ __device__ constexpr half_t type_convert<half_t, int>(int x)
+{
+    union
+    {
+        int int32;
+        half_t fp16;
+    } u = {x};
+
+    return u.fp16;
+}
+
 // convert fp32 to bfp16
 template <>
 inline __host__ __device__ constexpr bhalf_t type_convert<bhalf_t, float>(float x)
...
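The new type_convert specializations reinterpret the 16-bit half pattern through a union so it can ride through the int-typed permlane builtins and back. An equivalent standalone sketch using memcpy, with the compiler-extension type _Float16 standing in for ck::half_t (an assumption; the library defines its own half type):

#include <cstdint>
#include <cstring>

// Move a 16-bit half pattern into the low bits of an int, and back.
inline int half_bits_to_int(_Float16 x)
{
    std::uint16_t bits = 0;
    std::memcpy(&bits, &x, sizeof(bits)); // copy the raw 16-bit pattern
    return static_cast<int>(bits);        // zero-extend into the 32-bit int
}

inline _Float16 int_bits_to_half(int x)
{
    const std::uint16_t bits = static_cast<std::uint16_t>(x); // keep the low 16 bits
    _Float16 h;
    std::memcpy(&h, &bits, sizeof(bits));
    return h;
}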