Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d6ea89ec
"include/ck/utility/amd_inline_asm.hpp" did not exist on "86cc678f1824076467a011bd2d3e176214f7d99c"
Commit
d6ea89ec
authored
Oct 16, 2024
by
Mirza Halilcevic
Browse files
Add descriptor and RTC workarounds for batched_gemm_multiple_d_gemm_multiple_d.
parent
d20c20a6
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
377 additions
and
43 deletions
+377
-43
codegen/src/device_batched_gemm_multiple_d_gemm_multiple_d_operation_xdl_cshuffle.cpp
...emm_multiple_d_gemm_multiple_d_operation_xdl_cshuffle.cpp
+9
-9
include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp
...device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp
+4
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
..._batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
+359
-29
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
...tched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
+5
-5
No files found.
codegen/src/device_batched_gemm_multiple_d_gemm_multiple_d_operation_xdl_cshuffle.cpp
View file @
d6ea89ec
...
@@ -331,7 +331,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -331,7 +331,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
prob
.
N
,
prob
.
N
,
prob
.
K
,
prob
.
K
,
prob
.
O
,
prob
.
O
,
x
.
tile_desc
.
gemm0_m_per_block
,
x
.
tile_desc
.
gemm0
1
_m_per_block
,
x
.
tile_desc
.
gemm0_n_per_block
,
x
.
tile_desc
.
gemm0_n_per_block
,
x
.
tile_desc
.
gemm0_k_per_block
,
x
.
tile_desc
.
gemm0_k_per_block
,
x
.
tile_desc
.
gemm1_n_per_block
,
x
.
tile_desc
.
gemm1_n_per_block
,
...
@@ -404,13 +404,13 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
...
@@ -404,13 +404,13 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
values
=
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
values
=
{
{
"name"
,
{
"name"
,
std
::
to_string
(
this
->
tile_desc
.
block_size
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
block_size
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0
1
_m_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
a
0
k1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
b
0
k1
)
+
std
::
to_string
(
this
->
tile_desc
.
ak1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
bk1
)
+
"_"
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
b1k1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
b1k1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_Xdl_per_wave
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_Xdl_per_wave
)
+
"_"
+
...
@@ -426,7 +426,7 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
...
@@ -426,7 +426,7 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
MakeTuple
(
Transform
(
this
->
D1s
,
[](
auto
tensor
)
{
return
ToString
(
tensor
.
layout
);
}))},
MakeTuple
(
Transform
(
this
->
D1s
,
[](
auto
tensor
)
{
return
ToString
(
tensor
.
layout
);
}))},
{
"E1Layout"
,
ToString
(
this
->
E1
.
layout
)},
{
"E1Layout"
,
ToString
(
this
->
E1
.
layout
)},
{
"ADataType"
,
ToString
(
this
->
A0
.
element
)},
{
"A
0
DataType"
,
ToString
(
this
->
A0
.
element
)},
{
"B0DataType"
,
ToString
(
this
->
B0
.
element
)},
{
"B0DataType"
,
ToString
(
this
->
B0
.
element
)},
{
"Acc0DataType"
,
ToString
(
this
->
acc_type
)},
{
"Acc0DataType"
,
ToString
(
this
->
acc_type
)},
{
"D0sDataType"
,
{
"D0sDataType"
,
...
@@ -450,15 +450,15 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
...
@@ -450,15 +450,15 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
{
"PadGemm1N"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm1_n
)},
{
"PadGemm1N"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm1_n
)},
{
"PadGemm1K"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm1_k
)},
{
"PadGemm1K"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm1_k
)},
{
"NumGemm0KPrefetchStage"
,
std
::
to_string
(
this
->
tile_desc
.
num_gemm
0
k_prefetch_stage
)},
{
"NumGemm0KPrefetchStage"
,
std
::
to_string
(
this
->
tile_desc
.
num_gemmk_prefetch_stage
)},
{
"BlockSize"
,
std
::
to_string
(
this
->
tile_desc
.
block_size
)},
{
"BlockSize"
,
std
::
to_string
(
this
->
tile_desc
.
block_size
)},
{
"Gemm0MPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_per_block
)},
{
"Gemm0MPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0
1
_m_per_block
)},
{
"Gemm0NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)},
{
"Gemm0NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)},
{
"Gemm0KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)},
{
"Gemm0KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)},
{
"Gemm1NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)},
{
"Gemm1NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)},
{
"Gemm1KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)},
{
"Gemm1KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)},
{
"A0K1"
,
std
::
to_string
(
this
->
tile_desc
.
a
0
k1
)},
{
"A0K1"
,
std
::
to_string
(
this
->
tile_desc
.
ak1
)},
{
"B0K1"
,
std
::
to_string
(
this
->
tile_desc
.
b
0
k1
)},
{
"B0K1"
,
std
::
to_string
(
this
->
tile_desc
.
bk1
)},
{
"B1K1"
,
std
::
to_string
(
this
->
tile_desc
.
b1k1
)},
{
"B1K1"
,
std
::
to_string
(
this
->
tile_desc
.
b1k1
)},
{
"MPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)},
{
"MPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)},
{
"NPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)},
{
"NPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)},
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp
View file @
d6ea89ec
...
@@ -3,8 +3,10 @@
...
@@ -3,8 +3,10 @@
#pragma once
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <iostream>
#include <vector>
#include <vector>
#endif
#include "device_base.hpp"
#include "device_base.hpp"
...
@@ -31,6 +33,7 @@ template <typename A0Layout,
...
@@ -31,6 +33,7 @@ template <typename A0Layout,
typename
CDE1ElementwiseOperation
>
typename
CDE1ElementwiseOperation
>
struct
DeviceBatchedGemmMultipleDGemmMultipleD
:
public
BaseOperator
struct
DeviceBatchedGemmMultipleDGemmMultipleD
:
public
BaseOperator
{
{
#ifndef __HIPCC_RTC__
static
constexpr
index_t
NumD0Tensor
=
D0sDataType
::
Size
();
static
constexpr
index_t
NumD0Tensor
=
D0sDataType
::
Size
();
static
constexpr
index_t
NumD1Tensor
=
D1sDataType
::
Size
();
static
constexpr
index_t
NumD1Tensor
=
D1sDataType
::
Size
();
...
@@ -65,6 +68,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
...
@@ -65,6 +68,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
CDE1ElementwiseOperation
cde1_element_op
)
=
0
;
CDE1ElementwiseOperation
cde1_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
#endif
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
View file @
d6ea89ec
...
@@ -3,8 +3,12 @@
...
@@ -3,8 +3,12 @@
#pragma once
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <iostream>
#include <sstream>
#include <sstream>
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
...
@@ -13,8 +17,6 @@
...
@@ -13,8 +17,6 @@
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -350,9 +352,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -350,9 +352,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return
gemm1_padder
.
PadCDescriptor_M_N
(
e1_grid_desc_mraw_nraw
);
return
gemm1_padder
.
PadCDescriptor_M_N
(
e1_grid_desc_mraw_nraw
);
}
}
static
auto
MakeD0sGridDescriptor_M_N
(
const
std
::
a
rray
<
index_t
,
NumD1Tensor
>&
MRaws
,
static
auto
MakeD0sGridDescriptor_M_N
(
const
A
rray
<
index_t
,
NumD1Tensor
>&
MRaws
,
const
std
::
a
rray
<
index_t
,
NumD1Tensor
>&
NRaws
,
const
A
rray
<
index_t
,
NumD1Tensor
>&
NRaws
,
const
std
::
a
rray
<
index_t
,
NumD1Tensor
>&
DsStride
)
const
A
rray
<
index_t
,
NumD1Tensor
>&
DsStride
)
{
{
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
...
@@ -363,9 +365,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -363,9 +365,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
Number
<
NumD0Tensor
>
{});
Number
<
NumD0Tensor
>
{});
}
}
static
auto
MakeD1sGridDescriptor_M_N
(
const
std
::
a
rray
<
index_t
,
NumD1Tensor
>&
MRaws
,
static
auto
MakeD1sGridDescriptor_M_N
(
const
A
rray
<
index_t
,
NumD1Tensor
>&
MRaws
,
const
std
::
a
rray
<
index_t
,
NumD1Tensor
>&
NRaws
,
const
A
rray
<
index_t
,
NumD1Tensor
>&
NRaws
,
const
std
::
a
rray
<
index_t
,
NumD1Tensor
>&
DsStride
)
const
A
rray
<
index_t
,
NumD1Tensor
>&
DsStride
)
{
{
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i
)
{
[
&
](
auto
i
)
{
...
@@ -380,9 +382,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -380,9 +382,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
{
{
ComputeBasePtrOfStridedBatch
(
index_t
BatchStrideA0
,
ComputeBasePtrOfStridedBatch
(
index_t
BatchStrideA0
,
index_t
BatchStrideB0
,
index_t
BatchStrideB0
,
std
::
a
rray
<
index_t
,
NumD0Tensor
>
BatchStrideD0s
,
A
rray
<
index_t
,
NumD0Tensor
>
BatchStrideD0s
,
index_t
BatchStrideB1
,
index_t
BatchStrideB1
,
std
::
a
rray
<
index_t
,
NumD1Tensor
>
BatchStrideD1s
,
A
rray
<
index_t
,
NumD1Tensor
>
BatchStrideD1s
,
index_t
BatchStrideE1
)
index_t
BatchStrideE1
)
:
BatchStrideA0_
(
BatchStrideA0
),
:
BatchStrideA0_
(
BatchStrideA0
),
BatchStrideB0_
(
BatchStrideB0
),
BatchStrideB0_
(
BatchStrideB0
),
...
@@ -429,9 +431,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -429,9 +431,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
private:
private:
index_t
BatchStrideA0_
;
index_t
BatchStrideA0_
;
index_t
BatchStrideB0_
;
index_t
BatchStrideB0_
;
std
::
a
rray
<
index_t
,
NumD0Tensor
>
BatchStrideD0s_
;
A
rray
<
index_t
,
NumD0Tensor
>
BatchStrideD0s_
;
index_t
BatchStrideB1_
;
index_t
BatchStrideB1_
;
std
::
a
rray
<
index_t
,
NumD1Tensor
>
BatchStrideD1s_
;
A
rray
<
index_t
,
NumD1Tensor
>
BatchStrideD1s_
;
index_t
BatchStrideE1_
;
index_t
BatchStrideE1_
;
};
};
...
@@ -520,6 +522,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -520,6 +522,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultB1GridDescriptor_BK0_N_BK1
(
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultB1GridDescriptor_BK0_N_BK1
(
B1GridDesc_N_K
{}))
>
;
B1GridDesc_N_K
{}))
>
;
#ifndef __HIPCC_RTC__
// Argument
// Argument
struct
Argument
:
public
BaseArgument
struct
Argument
:
public
BaseArgument
{
{
...
@@ -790,6 +793,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -790,6 +793,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
}
};
};
#endif
static
constexpr
bool
IsValidCompilationParameter
()
static
constexpr
bool
IsValidCompilationParameter
()
{
{
...
@@ -799,9 +803,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -799,9 +803,9 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
// check if DsLayout is supported
// check if DsLayout is supported
template
<
typename
RefLayout
,
typename
DsLayout
,
const
index_t
NumDTensor
>
template
<
typename
RefLayout
,
typename
DsLayout
,
const
index_t
NumDTensor
>
static
bool
CheckDLayout
()
static
constexpr
bool
CheckDLayout
()
{
{
static
bool
valid
=
true
;
bool
valid
=
true
;
// iterate over DLayout tuple
// iterate over DLayout tuple
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
...
@@ -811,13 +815,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -811,13 +815,8 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return
valid
;
return
valid
;
}
}
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
static
constexpr
bool
IsSupported
()
{
if
(
!
ck
::
is_xdl_supported
())
{
{
return
false
;
}
// Check supported layouts
// Check supported layouts
// A0 - Row
// A0 - Row
// B0 - Col
// B0 - Col
...
@@ -829,16 +828,25 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -829,16 +828,25 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
B0Layout
>
&&
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
B0Layout
>
&&
CheckDLayout
<
tensor_layout
::
gemm
::
RowMajor
,
D0sLayout
,
NumD0Tensor
>
()
&&
CheckDLayout
<
tensor_layout
::
gemm
::
RowMajor
,
D0sLayout
,
NumD0Tensor
>
()
&&
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
B1Layout
>
||
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
B1Layout
>
||
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
B1Layout
>
)
&&
B1Layout
>
)
&&
CheckDLayout
<
tensor_layout
::
gemm
::
RowMajor
,
CheckDLayout
<
tensor_layout
::
gemm
::
RowMajor
,
D1sLayout
,
NumD1Tensor
>
()
&&
D1sLayout
,
NumD1Tensor
>
()
&&
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
E1Layout
>
))
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
E1Layout
>
))
{
{
return
false
;
return
false
;
}
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a0_grid_desc_m_k_
,
return
true
;
}
#ifndef __HIPCC_RTC__
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
ck
::
is_xdl_supported
())
{
return
false
;
}
return
IsSupported
()
and
GridwiseGemm
::
CheckValidity
(
arg
.
a0_grid_desc_m_k_
,
arg
.
b0_grid_desc_n_k_
,
arg
.
b0_grid_desc_n_k_
,
arg
.
b1_grid_desc_n_k_
,
arg
.
b1_grid_desc_n_k_
,
arg
.
e1_grid_desc_m_n_
,
arg
.
e1_grid_desc_m_n_
,
...
@@ -989,6 +997,328 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -989,6 +997,328 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return
str
.
str
();
return
str
.
str
();
}
}
#endif
// Aggregates everything derived from one set of input tensor descriptors:
// the padded "problem" descriptors, the descriptors derived from them through
// GridwiseGemm, the block-to-E1-tile map, the five element-wise operators and
// the has_main_k_block_loop flag. All construction is constexpr.
template <class A0Desc,
          class B0Desc,
          class D0sDesc,
          class B1Desc,
          class D1sDesc,
          class E1Desc>
struct Descriptor
{
    // for Gemm0: pad the A0 descriptor with gemm0_padder
    template <class A0GridDescriptor>
    static constexpr auto MakeA0GridDescriptor_M_K(const A0GridDescriptor& a0_grid_desc)
    {
        return gemm0_padder.PadADescriptor_M_K(a0_grid_desc);
    }

    // for Gemm0: pad the B0 descriptor with gemm0_padder
    template <class B0GridDescriptor>
    static constexpr auto MakeB0GridDescriptor_N_K(const B0GridDescriptor& b0_grid_desc)
    {
        return gemm0_padder.PadBDescriptor_N_K(b0_grid_desc);
    }

    // for Gemm0: apply gemm0_padder's C-descriptor padding to every element of the
    // D0s descriptor tuple
    template <class D0sGridDescriptor>
    static constexpr auto MakeD0sGridDescriptor_M_N(const D0sGridDescriptor& d0s_grid_desc)
    {
        return transform_tuples(
            [&](auto d) constexpr { return gemm0_padder.PadCDescriptor_M_N(d); },
            d0s_grid_desc);
    }

    // for Gemm1: pad the B1 descriptor with gemm1_padder
    template <class B1GridDescriptor>
    static constexpr auto MakeB1GridDescriptor_N_K(const B1GridDescriptor& b1_grid_desc)
    {
        return gemm1_padder.PadBDescriptor_N_K(b1_grid_desc);
    }

    // for Gemm1: apply gemm1_padder's C-descriptor padding to every element of the
    // D1s descriptor tuple
    template <class D1sGridDescriptor>
    static constexpr auto MakeD1sGridDescriptor_M_N(const D1sGridDescriptor& d1s_grid_desc)
    {
        return transform_tuples(
            [&](auto d) constexpr { return gemm1_padder.PadCDescriptor_M_N(d); },
            d1s_grid_desc);
    }

    // for Gemm1: pad the E1 (output) descriptor with gemm1_padder
    template <class E1GridDescriptor>
    static constexpr auto MakeE1GridDescriptor_M_N(const E1GridDescriptor& e1_grid_desc)
    {
        return gemm1_padder.PadCDescriptor_M_N(e1_grid_desc);
    }

    // Padded descriptor types, deduced by running the Make* helpers on
    // default-constructed input descriptors.
    using A0GridDesc_M_K  = decltype(MakeA0GridDescriptor_M_K(A0Desc{}));
    using B0GridDesc_N_K  = decltype(MakeB0GridDescriptor_N_K(B0Desc{}));
    using D0sGridDesc_M_N = remove_cvref_t<decltype(MakeD0sGridDescriptor_M_N(D0sDesc{}))>;
    using B1GridDesc_N_K  = decltype(MakeB1GridDescriptor_N_K(B1Desc{}));
    using D1sGridDesc_M_N = remove_cvref_t<decltype(MakeD1sGridDescriptor_M_N(D1sDesc{}))>;
    using E1GridDesc_M_N  = decltype(MakeE1GridDescriptor_M_N(E1Desc{}));

    // GridwiseGemm
    // Instantiated with the descriptor types deduced above, so this alias is
    // specific to this Descriptor specialization.
    using GridwiseGemm = GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle<
        A0DataType, // TODO: distinguish A/B datatype
        Acc0DataType,
        D0sDataType,
        Acc1DataType,
        C1ShuffleDataType,
        D1sDataType,
        E1DataType,
        A0ElementwiseOperation,
        B0ElementwiseOperation,
        CDE0ElementwiseOperation,
        B1ElementwiseOperation,
        CDE1ElementwiseOperation,
        InMemoryDataOperationEnum::Set,
        A0GridDesc_M_K,
        B0GridDesc_N_K,
        D0sGridDesc_M_N,
        B1GridDesc_N_K,
        D1sGridDesc_M_N,
        E1GridDesc_M_N,
        NumGemm0KPrefetchStage,
        BlockSize,
        Gemm0MPerBlock,
        Gemm0NPerBlock,
        Gemm0KPerBlock,
        Gemm1NPerBlock,
        Gemm1KPerBlock,
        A0K1,
        B0K1,
        B1K1,
        Gemm0MPerXdl,
        Gemm0NPerXdl,
        Gemm0MXdlPerWave,
        Gemm0NXdlPerWave,
        Gemm1NXdlPerWave,
        A0BlockTransferThreadClusterLengths_AK0_M_AK1,
        A0BlockTransferThreadClusterArrangeOrder,
        A0BlockTransferSrcAccessOrder,
        A0BlockTransferSrcVectorDim,
        A0BlockTransferSrcScalarPerVector,
        A0BlockTransferDstScalarPerVector_AK1,
        true,
        A0BlockLdsExtraM,
        B0BlockTransferThreadClusterLengths_BK0_N_BK1,
        B0BlockTransferThreadClusterArrangeOrder,
        B0BlockTransferSrcAccessOrder,
        B0BlockTransferSrcVectorDim,
        B0BlockTransferSrcScalarPerVector,
        B0BlockTransferDstScalarPerVector_BK1,
        true,
        B0BlockLdsExtraN,
        CDE0BlockTransferSrcVectorDim,
        CDE0BlockTransferSrcScalaerPerVector,
        B1BlockTransferThreadClusterLengths_BK0_N_BK1,
        B1BlockTransferThreadClusterArrangeOrder,
        B1BlockTransferSrcAccessOrder,
        B1BlockTransferSrcVectorDim,
        B1BlockTransferSrcScalarPerVector,
        B1BlockTransferDstScalarPerVector_BK1,
        false,
        B1BlockLdsExtraN,
        C1ShuffleMXdlPerWavePerShuffle,
        C1ShuffleGemm0NXdlPerWavePerShuffle,
        CDE1ShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
        CDE1ShuffleBlockTransferScalarPerVector_NPerBlock,
        LoopSched>;

    // Types of the descriptors / tile map that GridwiseGemm derives from the
    // padded descriptors above.
    using A0GridDesc_AK0_M_AK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(A0GridDesc_M_K{}))>;
    using B0GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(B0GridDesc_N_K{}))>;
    using B1GridDesc_BK0_N_BK1 = remove_cvref_t<decltype(
        GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(B1GridDesc_N_K{}))>;
    using E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(E1GridDesc_M_N{}))>;
    using D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 = remove_cvref_t<decltype(
        GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(D0sGridDesc_M_N{}))>;
    using D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(D1sGridDesc_M_N{}))>;
    using DefaultBlock2E1TileMap =
        remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2E1TileMap(E1GridDesc_M_N{}))>;

    // NOTE: the data members below are initialized in declaration order, and the
    // constructor builds the later ones from the earlier ones. Do not reorder them.

    // tensor descriptors for problem definition
    A0GridDesc_M_K a0_grid_desc_m_k;
    B0GridDesc_N_K b0_grid_desc_n_k;
    D0sGridDesc_M_N d0s_grid_desc_m_n;
    B1GridDesc_N_K b1_grid_desc_n_k;
    D1sGridDesc_M_N d1s_grid_desc_m_n;
    E1GridDesc_M_N e1_grid_desc_m_n;

    // tensor descriptors for block/thread-wise copy
    A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1;
    B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1;
    D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5;
    B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1;
    D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
        d1s_grid_desc_mblock_mperblock_nblock_nperblock;
    E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
        e1_grid_desc_mblock_mperblock_nblock_nperblock;

    // block-to-e1-tile map
    DefaultBlock2E1TileMap block_2_e1tile_map;

    // element-wise op
    A0ElementwiseOperation a0_element_op;
    B0ElementwiseOperation b0_element_op;
    CDE0ElementwiseOperation cde0_element_op;
    B1ElementwiseOperation b1_element_op;
    CDE1ElementwiseOperation cde1_element_op;

    // Default is true; the constructor overwrites it with the value returned by
    // GridwiseGemm::CalculateHasMainKBlockLoop.
    bool has_main_k_block_loop = true;

    // Builds every member from the six input descriptors and the five element-wise
    // operators: first the padded problem descriptors (via the Make* helpers), then
    // the GridwiseGemm-derived descriptors and the tile map (from those padded
    // descriptors), then the operators, and finally has_main_k_block_loop from
    // a0_grid_desc_m_k.GetLength(I1) (presumably the K length -- confirm against
    // the A0 descriptor's dimension order).
    constexpr Descriptor(A0Desc a0,
                         B0Desc b0,
                         D0sDesc d0s,
                         B1Desc b1,
                         D1sDesc d1s,
                         E1Desc e1,
                         A0ElementwiseOperation a0_element_op_,
                         B0ElementwiseOperation b0_element_op_,
                         CDE0ElementwiseOperation cde0_element_op_,
                         B1ElementwiseOperation b1_element_op_,
                         CDE1ElementwiseOperation cde1_element_op_)
        : a0_grid_desc_m_k{MakeA0GridDescriptor_M_K(a0)},
          b0_grid_desc_n_k{MakeB0GridDescriptor_N_K(b0)},
          d0s_grid_desc_m_n{MakeD0sGridDescriptor_M_N(d0s)},
          b1_grid_desc_n_k{MakeB1GridDescriptor_N_K(b1)},
          d1s_grid_desc_m_n{MakeD1sGridDescriptor_M_N(d1s)},
          e1_grid_desc_m_n{MakeE1GridDescriptor_M_N(e1)},
          a0_grid_desc_ak0_m_ak1{
              GridwiseGemm::MakeDefaultA0GridDescriptor_AK0_M_AK1(a0_grid_desc_m_k)},
          b0_grid_desc_bk0_n_bk1{
              GridwiseGemm::MakeDefaultB0GridDescriptor_BK0_N_BK1(b0_grid_desc_n_k)},
          d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5{
              GridwiseGemm::MakeD0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5(
                  d0s_grid_desc_m_n)},
          b1_grid_desc_bk0_n_bk1{
              GridwiseGemm::MakeDefaultB1GridDescriptor_BK0_N_BK1(b1_grid_desc_n_k)},
          d1s_grid_desc_mblock_mperblock_nblock_nperblock{
              GridwiseGemm::MakeD1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                  d1s_grid_desc_m_n)},
          e1_grid_desc_mblock_mperblock_nblock_nperblock{
              GridwiseGemm::MakeE1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                  e1_grid_desc_m_n)},
          block_2_e1tile_map{GridwiseGemm::MakeDefaultBlock2E1TileMap(e1_grid_desc_m_n)},
          a0_element_op{a0_element_op_},
          b0_element_op{b0_element_op_},
          cde0_element_op{cde0_element_op_},
          b1_element_op{b1_element_op_},
          cde1_element_op{cde1_element_op_},
          has_main_k_block_loop{
              GridwiseGemm::CalculateHasMainKBlockLoop(a0_grid_desc_m_k.GetLength(I1))}
    {
    }

    // Returns true when IsSupported() (not declared in this struct; resolved from
    // the enclosing scope) holds and GridwiseGemm::CheckValidity accepts the
    // A0/B0/B1/E1 problem descriptors together with the block-to-E1-tile map.
    // The D0s/D1s descriptors are not passed to CheckValidity here.
    constexpr bool IsValid() const
    {
        return IsSupported() and GridwiseGemm::CheckValidity(a0_grid_desc_m_k,
                                                             b0_grid_desc_n_k,
                                                             b1_grid_desc_n_k,
                                                             e1_grid_desc_m_n,
                                                             block_2_e1tile_map);
    }
};
// Factory for Descriptor: deduces the six tensor-descriptor types from the
// arguments so callers do not have to spell out the Descriptor specialization.
// Each element-wise operator is optional and falls back to a value-initialized
// instance of its operator type.
template <class A0Desc,
          class B0Desc,
          class D0sDesc,
          class B1Desc,
          class D1sDesc,
          class E1Desc>
static constexpr auto
make_descriptor(A0Desc a0,
                B0Desc b0,
                D0sDesc d0s,
                B1Desc b1,
                D1sDesc d1s,
                E1Desc e1,
                A0ElementwiseOperation a0_element_op     = A0ElementwiseOperation{},
                B0ElementwiseOperation b0_element_op     = B0ElementwiseOperation{},
                CDE0ElementwiseOperation cde0_element_op = CDE0ElementwiseOperation{},
                B1ElementwiseOperation b1_element_op     = B1ElementwiseOperation{},
                CDE1ElementwiseOperation cde1_element_op = CDE1ElementwiseOperation{})
{
    // Name the concrete specialization once, then direct-initialize it.
    using DescriptorType = Descriptor<A0Desc, B0Desc, D0sDesc, B1Desc, D1sDesc, E1Desc>;

    return DescriptorType(a0,
                          b0,
                          d0s,
                          b1,
                          d1s,
                          e1,
                          a0_element_op,
                          b0_element_op,
                          cde0_element_op,
                          b1_element_op,
                          cde1_element_op);
}
// Device-side entry point: runs Desc::GridwiseGemm on the given grid pointers
// using the descriptors, tile map and element-wise operators stored in `desc`.
//
// desc        descriptor bundle; must expose GridwiseGemm, IsValid(),
//             has_main_k_block_loop, the *_element_op members and the
//             grid descriptors / tile map read below
// p_a0_grid   A0 input pointer
// p_b0_grid   B0 input pointer
// p_d0s_grid  D0s pointer bundle, forwarded unchanged to GridwiseGemm::Run
// p_b1_grid   B1 input pointer
// p_d1s_grid  D1s pointer bundle, forwarded unchanged to GridwiseGemm::Run
// p_e1_grid   E1 output pointer
template <class Desc, class D0sPointer, class D1sPointer>
__device__ static void Run(const Desc& desc,
                           const A0DataType* __restrict__ p_a0_grid,
                           const B0DataType* __restrict__ p_b0_grid,
                           D0sPointer p_d0s_grid,
                           const B1DataType* __restrict__ p_b1_grid,
                           D1sPointer p_d1s_grid,
                           E1DataType* __restrict__ p_e1_grid)
{
    // Size the shared-memory scratch buffer from the same GridwiseGemm
    // instantiation that is executed below (Desc::GridwiseGemm), rather than from
    // an unrelated GridwiseGemm alias of the enclosing scope, so buffer size and
    // kernel always agree.
    __shared__ char p_shared_block[Desc::GridwiseGemm::GetSharedMemoryNumberOfByte()];

    // The validity assertion is compiled out when building with __HIPCC_RTC__.
#ifndef __HIPCC_RTC__
    assert(desc.IsValid());
#endif

    // has_main_k_block_loop is a runtime flag on desc, but GridwiseGemm::Run takes
    // it as a template argument, hence the two otherwise identical calls.
    if(desc.has_main_k_block_loop)
    {
        Desc::GridwiseGemm::template Run<true>(
            p_a0_grid,
            p_b0_grid,
            p_d0s_grid,
            p_b1_grid,
            p_d1s_grid,
            p_e1_grid,
            p_shared_block,
            desc.a0_element_op,
            desc.b0_element_op,
            desc.cde0_element_op,
            desc.b1_element_op,
            desc.cde1_element_op,
            desc.a0_grid_desc_ak0_m_ak1,
            desc.b0_grid_desc_bk0_n_bk1,
            desc.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            desc.b1_grid_desc_bk0_n_bk1,
            desc.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
            desc.e1_grid_desc_mblock_mperblock_nblock_nperblock,
            desc.block_2_e1tile_map);
    }
    else
    {
        Desc::GridwiseGemm::template Run<false>(
            p_a0_grid,
            p_b0_grid,
            p_d0s_grid,
            p_b1_grid,
            p_d1s_grid,
            p_e1_grid,
            p_shared_block,
            desc.a0_element_op,
            desc.b0_element_op,
            desc.cde0_element_op,
            desc.b1_element_op,
            desc.cde1_element_op,
            desc.a0_grid_desc_ak0_m_ak1,
            desc.b0_grid_desc_bk0_n_bk1,
            desc.d0s_grid_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
            desc.b1_grid_desc_bk0_n_bk1,
            desc.d1s_grid_desc_mblock_mperblock_nblock_nperblock,
            desc.e1_grid_desc_mblock_mperblock_nblock_nperblock,
            desc.block_2_e1tile_map);
    }
}
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
View file @
d6ea89ec
...
@@ -303,10 +303,10 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -303,10 +303,10 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return
false
;
return
false
;
}
}
if
(
!
block_2_e1tile_map
.
CheckValidity
(
e1_grid_desc_m_n
))
//
if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n))
{
//
{
return
false
;
//
return false;
}
//
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
true
;
...
@@ -952,7 +952,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -952,7 +952,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
else
else
{
{
static_for
<
0
,
acc0_thread_buf
.
Size
(),
1
>
{}(
static_for
<
0
,
acc0_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
cde0_element_op
(
acc_thread_buf
(
i
),
acc0_thread_buf
[
i
]);
});
[
&
](
auto
i
)
{
cde0_element_op
(
acc
0
_thread_buf
(
i
),
acc0_thread_buf
[
i
]);
});
}
}
// gemm1
// gemm1
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment