gaoqiong / composable_kernel · Commits · 0d6da13c

Unverified commit 0d6da13c, authored May 05, 2023 by rocking, committed by GitHub on May 05, 2023.

    Merge branch 'develop' into normalization/splitK

Parents: 27f8c64b, b076a02a

Changes: 118 files in the merge; the page shows the first 20 changed files, with 938 additions and 434 deletions (+938, -434).
Changed files shown:

  include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp            +4    -2
  include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp            +4    -2
  include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp                      +88   -1
  include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp                       +4    -2
  include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp                                          +70   -1
  include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp                                           +1    -1
  include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp                                  +2    -1
  include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp                        +2    -1
  include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp                          +60   -412
  include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp        +4    -2
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp   +2    -1
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp +2   -1
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp +3   -2
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp           +3    -2
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp                                   +2    -1
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp                   +613  -0
  include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                                  +23   -0
  include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp                                               +48   -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp                    +2    -1
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp                            +1    -1
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp

@@ -63,7 +63,8 @@ __global__ void
     const Block2ETileMap block_2_etile_map,
     index_t NRaw)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];

     GridwiseGemmWelford::template Run<HasMainKBlockLoop>(
...
@@ -854,7 +855,8 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940"))
         {
             return false;
         }
...
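Every device op touched by this merge applies the same two-level gfx940 enablement: a preprocessor guard around the kernel body (evaluated when the device pass compiles for a given target) and a runtime ck::get_device_name() check in IsSupportedArgument. A minimal standalone sketch of that pattern follows; the helper name is_supported_device and the trivial kernel are illustrative only, not part of the commit.

    // Sketch of the two-level architecture gating this merge extends to gfx940.
    #include <hip/hip_runtime.h>
    #include <string>

    // Host-side check, mirroring IsSupportedArgument: refuse to run on
    // architectures the kernel was not written for.
    inline bool is_supported_device(const std::string& device_name)
    {
        return device_name == "gfx908" || device_name == "gfx90a" || device_name == "gfx940";
    }

    // Device-side check, mirroring the kernel-body guard: only emit code when
    // the compiler targets a supported GPU (the macro test is a no-op on the
    // host compilation pass, where __HIP_DEVICE_COMPILE__ is undefined).
    __global__ void gated_kernel_example(float* out)
    {
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
        defined(__gfx940__))
        out[threadIdx.x] = 1.0f; // placeholder body
    #else
        (void)out; // unsupported target: compile an empty kernel
    #endif
    }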
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_multiple_r_xdl_cshuffle.hpp

@@ -60,7 +60,8 @@ __global__ void
     const RsGridDescriptor_MBlock_MPerBlock rs_grid_desc_mblock_mperblock,
     const Block2ETileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
@@ -554,7 +555,8 @@ struct DeviceGemmMultipleDMultipleR_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940"))
         {
             return false;
         }
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp

@@ -273,7 +273,10 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
               N01_{N01},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
-              cde_element_op_{cde_element_op}
+              cde_element_op_{cde_element_op},
+              MRaw_{M},
+              NRaw_{N},
+              KRaw_{K}
         {
             a_grid_desc_k0_m_k1_ = DeviceOp::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
             b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
...
@@ -335,6 +338,11 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CDEElementwiseOperation cde_element_op_;
+
+        // for checking vector load/store
+        index_t MRaw_;
+        index_t NRaw_;
+        index_t KRaw_;
     };

     // Invoker
...
@@ -488,6 +496,85 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
         {
             return false;
         }

+        // check vector load/store
+        {
+            using Row = ck::tensor_layout::gemm::RowMajor;
+            using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+            // check vector load of A
+            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
+            {
+                if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
+            {
+                // FIXME: not rigorous
+                if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+            else
+            {
+                return false;
+            }
+
+            // check vector load of B
+            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
+            {
+                if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
+            {
+                // FIXME: not rigorous
+                if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+            else
+            {
+                return false;
+            }
+
+            // check vector load of Ds
+            // only support RowMajor for now
+            bool all_valid = true;
+
+            static_for<0, NumDTensor, 1>{}([&](auto i) {
+                using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
+
+                if constexpr(!is_same_v<DLayout, Row>)
+                {
+                    all_valid = false;
+                }
+            });
+
+            if(!all_valid)
+            {
+                return false;
+            }
+
+            // check vector store of E
+            // only support RowMajor for now
+            if constexpr(is_same_v<ELayout, Row>)
+            {
+                if(arg.NRaw_ % CDEShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+                {
+                    return false;
+                }
+            }
+            else
+            {
+                return false;
+            }
+        }
+
         return GridwiseOp::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                          arg.b_grid_desc_k0_n_k1_,
...
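The MRaw_/NRaw_/KRaw_ members added above exist only so IsSupportedArgument can verify that the contiguous dimension of each tensor is divisible by the vector width used for global loads and stores. A small standalone sketch of that divisibility rule (hypothetical helper, simplified to one dimension):

    #include <cstdint>

    // A vectorized transfer moves `scalars_per_vector` contiguous elements at a
    // time, so the raw length of the contiguous dimension must divide evenly;
    // otherwise the last vector would straddle the edge of the tensor.
    inline bool can_vectorize(std::int64_t contiguous_length, std::int64_t scalars_per_vector)
    {
        return contiguous_length % scalars_per_vector == 0;
    }

    // Example: with 8-wide loads, K = 1024 passes (1024 % 8 == 0) while K = 1020
    // is rejected (1020 % 8 == 4), exactly as the new
    // KRaw_ % ABlockTransferSrcScalarPerVector check would report.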
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp

@@ -51,7 +51,8 @@ __global__ void
                  e_grid_desc_mblock_mperblock_nblock_nperblock,
     const Block2ETileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
@@ -490,7 +491,8 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940"))
         {
             return false;
         }
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp

@@ -239,7 +239,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
               N01_{N01},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
-              c_element_op_{c_element_op}
+              c_element_op_{c_element_op},
+              MRaw_{M},
+              NRaw_{N},
+              KRaw_{K}
         {
             a_grid_desc_k0_m_k1_ =
                 DeviceGemmWmma_CShuffle::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
...
@@ -276,6 +279,10 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
+        // for checking vector load/store
+        index_t MRaw_;
+        index_t NRaw_;
+        index_t KRaw_;
     };

     // Invoker
...
@@ -417,6 +424,68 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             return false;
         }

+        // check vector load/store
+        {
+            using Row = ck::tensor_layout::gemm::RowMajor;
+            using Col = ck::tensor_layout::gemm::ColumnMajor;
+
+            // check vector load of A
+            if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
+            {
+                if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+            else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
+            {
+                // FIXME: not rigorous
+                if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+            else
+            {
+                return false;
+            }
+
+            // check vector load of B
+            if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
+            {
+                if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+            else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
+            {
+                // FIXME: not rigorous
+                if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
+                {
+                    return false;
+                }
+            }
+            else
+            {
+                return false;
+            }
+
+            // check vector store of C
+            // only support RowMajor for now
+            if constexpr(is_same_v<CLayout, Row>)
+            {
+                if(arg.NRaw_ % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
+                {
+                    return false;
+                }
+            }
+            else
+            {
+                return false;
+            }
+        }
+
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
                                            arg.b_grid_desc_k0_n_k1_,
                                            arg.c_grid_desc_m_n_,
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp

@@ -428,7 +428,7 @@ struct DeviceGemmXdl : public DeviceGemm<ALayout,
                 return false;
             }
         }
-        else if(ck::get_device_name() == "gfx90a")
+        else if(ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940")
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
                            is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp

@@ -574,7 +574,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940"))
         {
             return false;
         }
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_layernorm_cshuffle.hpp

@@ -648,7 +648,8 @@ struct DeviceGemmLayerNorm_Xdl_CShuffle : public BaseOperator
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940"))
         {
             return false;
         }
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_splitk_c_shuffle.hpp

@@ -73,157 +73,18 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

-    static constexpr auto K1Number = Number<K1>{};
-
-    static auto
-    MakeAGridDescriptor_KBatch_K0_M_K1(index_t M, index_t K, index_t StrideA, int KBatch, int KPad)
-    {
-        assert(KPad % (K1 * KBatch) == 0);
-
-        const index_t K0 = KPad / (K1 * KBatch);
-
-        const auto a_grid_desc_m_k = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
-            }
-        }();
-
-        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
-            a_grid_desc_m_k,
-            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-
-            return transform_tensor_descriptor(
-                a_grid_desc_m_kpad,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
-                           make_right_pad_transform(M, PadM)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                a_grid_desc_m_kpad,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
-                           make_pass_through_transform(M)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
-        }
-    }
-
-    static auto
-    MakeBGridDescriptor_KBatch_K0_N_K1(index_t K, index_t N, index_t StrideB, int KBatch, int KPad)
-    {
-        assert(KPad % (K1 * KBatch) == 0);
-
-        const index_t K0 = KPad / (K1 * KBatch);
-
-        const auto b_grid_desc_k_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
-            }
-        }();
-
-        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
-            b_grid_desc_k_n,
-            make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-            return transform_tensor_descriptor(
-                b_grid_desc_kpad_n,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
-                           make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                b_grid_desc_kpad_n,
-                make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));
-        }
-    }
-
-    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
-    {
-        const auto c_grid_desc_m_n = [&]() {
-            if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
-            }
-            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
-            {
-                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
-            }
-        }();
-
-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
-        {
-            const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
-
-            return transform_tensor_descriptor(
-                c_grid_desc_m_n,
-                make_tuple(make_right_pad_transform(M, PadM), make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-        else
-        {
-            return transform_tensor_descriptor(
-                c_grid_desc_m_n,
-                make_tuple(make_pass_through_transform(M), make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0>{}, Sequence<1>{}));
-        }
-    }
-
-    static auto GetKPad(index_t K, index_t KBatch)
-    {
-        const index_t K0   = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
-        const index_t KPad = KBatch * K0 * K1;
-        return KPad;
-    }
-
-    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1));
-    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1));
-    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
-
     // GridwiseGemm
     using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
         BlockSize,
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CDataType,
-        InMemoryDataOperationEnum::Set,
-        AGridDesc_K0_M_K1,
-        BGridDesc_K0_N_K1,
-        CGridDesc_M_N,
+        ALayout,
+        BLayout,
+        CLayout,
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,
+        GemmSpec,
         MPerBlock,
         NPerBlock,
         K0PerBlock,
...
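The removed GetKPad helper above padded K so that it splits evenly into KBatch slices of whole K0PerBlock * K1 chunks: K0 = ceil(K / (K1 * K0PerBlock * KBatch)) * K0PerBlock and KPad = KBatch * K0 * K1. The replacement GridwiseGemm::CalculateKPadded/CalculateK0 helpers perform this internally; a short numeric sketch of the arithmetic (illustrative sizes only):

    #include <cstdio>

    // Pad K so it divides evenly into KBatch split-K slices of K0PerBlock * K1 chunks.
    int main()
    {
        const int K = 1000, K1 = 8, K0PerBlock = 4, KBatch = 4; // illustrative values

        const int chunk = K1 * K0PerBlock * KBatch;               // 128 elements per full round
        const int K0    = ((K + chunk - 1) / chunk) * K0PerBlock; // ceil(1000/128) * 4 = 32
        const int KPad  = KBatch * K0 * K1;                       // 4 * 32 * 8 = 1024

        std::printf("K=%d -> K0=%d per batch, KPad=%d\n", K, K0, KPad);
        return 0;
    }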
@@ -253,236 +114,68 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
         CBlockTransferScalarPerVector_NWaveNPerXDL,
         CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;

-    // GridwiseGemm
-    using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
-        BlockSize,
-        ADataType, // TODO: distinguish A/B datatype
-        AccDataType,
-        CDataType,
-        InMemoryDataOperationEnum::AtomicAdd,
-        AGridDesc_K0_M_K1,
-        BGridDesc_K0_N_K1,
-        CGridDesc_M_N,
-        AElementwiseOperation,
-        BElementwiseOperation,
-        CElementwiseOperation,
-        MPerBlock,
-        NPerBlock,
-        K0PerBlock,
-        MPerXDL,
-        NPerXDL,
-        K1,
-        MXdlPerWave,
-        NXdlPerWave,
-        ABlockTransferThreadClusterLengths_K0_M_K1,
-        ABlockTransferThreadClusterArrangeOrder,
-        ABlockTransferSrcAccessOrder,
-        ABlockTransferSrcVectorDim,
-        ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_K1,
-        false, // AThreadTransferSrcResetCoordinateAfterRun,
-        ABlockLdsAddExtraM,
-        BBlockTransferThreadClusterLengths_K0_N_K1,
-        BBlockTransferThreadClusterArrangeOrder,
-        BBlockTransferSrcAccessOrder,
-        BBlockTransferSrcVectorDim,
-        BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_K1,
-        false, // BThreadTransferSrcResetCoordinateAfterRun,
-        BBlockLdsAddExtraN,
-        CShuffleMRepeatPerShuffle,
-        CShuffleNRepeatPerShuffle,
-        CBlockTransferScalarPerVector_NWaveNPerXDL,
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;
-
-    using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
-        decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
-    using Block2CTileMap = typename GridwiseGemm::CBlockClusterAdaptor;
-
-    // Argument
-    struct Argument : public BaseArgument
-    {
-        Argument(const ADataType* p_a_grid,
-                 const BDataType* p_b_grid,
-                 CDataType* p_c_grid,
-                 index_t M,
-                 index_t N,
-                 index_t K,
-                 index_t StrideA,
-                 index_t StrideB,
-                 index_t StrideC,
-                 index_t M01,
-                 index_t N01,
-                 AElementwiseOperation a_element_op,
-                 BElementwiseOperation b_element_op,
-                 CElementwiseOperation c_element_op,
-                 index_t k_batch)
-            : p_a_grid_{p_a_grid},
-              p_b_grid_{p_b_grid},
-              p_c_grid_{p_c_grid},
-              a_grid_desc_kbatch_k0_m_k1_{},
-              b_grid_desc_kbatch_k0_n_k1_{},
-              c_grid_desc_m_n_{},
-              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
-              block_2_ctile_map_{},
-              M01_{M01},
-              N01_{N01},
-              a_element_op_{a_element_op},
-              b_element_op_{b_element_op},
-              c_element_op_{c_element_op},
-              k_batch_{k_batch}
-        {
-            int KPad = DeviceGemmXdlSplitKCShuffle::GetKPad(K, k_batch_);
-
-            a_grid_desc_kbatch_k0_m_k1_ = DeviceGemmXdlSplitKCShuffle::
-                MakeAGridDescriptor_KBatch_K0_M_K1(M, K, StrideA, k_batch_, KPad);
-            b_grid_desc_kbatch_k0_n_k1_ = DeviceGemmXdlSplitKCShuffle::
-                MakeBGridDescriptor_KBatch_K0_N_K1(K, N, StrideB, k_batch_, KPad);
-            c_grid_desc_m_n_ = DeviceGemmXdlSplitKCShuffle::MakeCGridDescriptor_M_N(M, N, StrideC);
-
-            block_2_ctile_map_ =
-                GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
-
-            if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
-                                           b_grid_desc_kbatch_k0_n_k1_,
-                                           c_grid_desc_m_n_,
-                                           block_2_ctile_map_))
-            {
-                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n_);
-            }
-        }
-
-        // private:
-        const ADataType* p_a_grid_;
-        const BDataType* p_b_grid_;
-        CDataType* p_c_grid_;
-        AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
-        BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
-        CGridDesc_M_N c_grid_desc_m_n_;
-        CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
-        Block2CTileMap block_2_ctile_map_;
-        index_t M01_;
-        index_t N01_;
-        AElementwiseOperation a_element_op_;
-        BElementwiseOperation b_element_op_;
-        CElementwiseOperation c_element_op_;
-        index_t k_batch_;
-    };
+    using Argument = typename GridwiseGemm::Argument;
+    using DefaultBlock2CTileMap = typename GridwiseGemm::DefaultBlock2CTileMap;

     // Invoker
     struct Invoker : public BaseInvoker
     {
-        void Print(const Argument& arg)
-        {
-            std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
-                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
-                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
-                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
-                      << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
-
-            std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
-                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
-                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
-                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
-                      << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
-
-            std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
-                      << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-        }
+        using Argument = DeviceGemmXdlSplitKCShuffle::Argument;
+
+        void Print(const Argument& karg) { karg.Print(); }

-        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
         {
             if(stream_config.log_level_ > 0)
             {
-                Print(arg);
+                Print(karg);
             }

-            const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);
+            const auto kbatch = karg.k_batch;

-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
-                                            arg.b_grid_desc_kbatch_k0_n_k1_,
-                                            arg.c_grid_desc_m_n_,
-                                            arg.block_2_ctile_map_))
+            if(!GridwiseGemm::CheckValidity(karg))
             {
                 throw std::runtime_error(
-                    "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid setting");
+                    "wrong! GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2 has invalid "
+                    "setting");
             }

-            const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
+            const auto b2c_map = DefaultBlock2CTileMap{};
+            index_t gdx, gdy, gdz;
+            std::tie(gdx, gdy, gdz) = b2c_map.CalculateGridSize(karg.M, karg.N, karg.k_batch);

-            const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
+            const auto K0 = karg.K0;

             const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

             float ave_time = 0;

             const auto Run = [&](const auto& kernel) {
-                hipGetErrorString(hipMemset(
-                    arg.p_c_grid_,
-                    0,
-                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_.GetElementSpaceSize() *
-                        sizeof(CDataType)));
+                if(kbatch > 1)
+                    hipGetErrorString(
+                        hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(CDataType)));

-                ave_time = launch_and_time_kernel(stream_config,
-                                                  kernel,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.a_grid_desc_kbatch_k0_m_k1_,
-                                                  arg.b_grid_desc_kbatch_k0_n_k1_,
-                                                  arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_,
-                                                  arg.block_2_ctile_map_);
+                ave_time = launch_and_time_kernel(
+                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg, b2c_map);
             };

             if(has_main_k0_block_loop)
             {
                 if(kbatch == 1)
                 {
-                    const auto kernel = kernel_gemm_xdlops_v2r4r2<
-                        GridwiseGemm,
-                        ADataType, // TODO: distiguish A/B datatype
-                        CDataType,
-                        InMemoryDataOperationEnum::Set,
-                        remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
-                        remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
-                        remove_reference_t<
-                            DeviceGemmXdlSplitKCShuffle::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                        AElementwiseOperation,
-                        BElementwiseOperation,
-                        CElementwiseOperation,
-                        remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
-                        true>;
+                    const auto kernel =
+                        kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
+                                                             true,
+                                                             DefaultBlock2CTileMap>;

                     Run(kernel);
                 }
                 else
                 {
-                    const auto kernel = kernel_gemm_xdlops_v2r4r2<
-                        GridwiseGemmAtomicAdd,
-                        ADataType, // TODO: distiguish A/B datatype
-                        CDataType,
-                        InMemoryDataOperationEnum::AtomicAdd,
-                        remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
-                        remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
-                        remove_reference_t<
-                            DeviceGemmXdlSplitKCShuffle::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                        AElementwiseOperation,
-                        BElementwiseOperation,
-                        CElementwiseOperation,
-                        remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
-                        true>;
+                    const auto kernel =
+                        kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
+                                                             true,
+                                                             DefaultBlock2CTileMap>;

                     Run(kernel);
                 }
...
@@ -491,37 +184,21 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
             {
                 if(kbatch == 1)
                 {
-                    const auto kernel = kernel_gemm_xdlops_v2r4r2<
-                        GridwiseGemm,
-                        ADataType, // TODO: distiguish A/B datatype
-                        CDataType,
-                        InMemoryDataOperationEnum::Set,
-                        remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
-                        remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
-                        remove_reference_t<
-                            DeviceGemmXdlSplitKCShuffle::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                        AElementwiseOperation,
-                        BElementwiseOperation,
-                        CElementwiseOperation,
-                        remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
-                        false>;
+                    const auto kernel =
+                        kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
+                                                             false,
+                                                             DefaultBlock2CTileMap>;

                     Run(kernel);
                 }
                 else
                 {
-                    const auto kernel = kernel_gemm_xdlops_v2r4r2<
-                        GridwiseGemmAtomicAdd,
-                        ADataType, // TODO: distiguish A/B datatype
-                        CDataType,
-                        InMemoryDataOperationEnum::AtomicAdd,
-                        remove_reference_t<DeviceGemmXdlSplitKCShuffle::AGridDesc_K0_M_K1>,
-                        remove_reference_t<DeviceGemmXdlSplitKCShuffle::BGridDesc_K0_N_K1>,
-                        remove_reference_t<
-                            DeviceGemmXdlSplitKCShuffle::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                        AElementwiseOperation,
-                        BElementwiseOperation,
-                        CElementwiseOperation,
-                        remove_reference_t<DeviceGemmXdlSplitKCShuffle::Block2CTileMap>,
-                        false>;
+                    const auto kernel =
+                        kernel_gemm_xdlops_v2r4r2_simplified<GridwiseGemm,
+                                                             false,
+                                                             DefaultBlock2CTileMap>;

                     Run(kernel);
                 }
...
@@ -544,12 +221,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
         return true;
     }

-    static bool IsSupportedArgument(const Argument& arg)
+    static bool IsSupportedArgument(const Argument& karg)
     {
-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
-                                           arg.b_grid_desc_kbatch_k0_n_k1_,
-                                           arg.c_grid_desc_m_n_,
-                                           arg.block_2_ctile_map_);
+        return GridwiseGemm::CheckValidity(karg);
     }

     // polymorphic
...
@@ -567,9 +241,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                              index_t StrideA,
                              index_t StrideB,
                              index_t StrideC,
-                             AElementwiseOperation a_element_op,
-                             BElementwiseOperation b_element_op,
-                             CElementwiseOperation c_element_op,
+                             AElementwiseOperation,
+                             BElementwiseOperation,
+                             CElementwiseOperation,
                              index_t KBatch)
     {
         return Argument{p_a,
...
@@ -581,11 +255,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         StrideA,
                         StrideB,
                         StrideC,
-                        1,
-                        1,
-                        a_element_op,
-                        b_element_op,
-                        c_element_op,
+                        GridwiseGemm::CalculateMPadded(M),
+                        GridwiseGemm::CalculateNPadded(N),
+                        GridwiseGemm::CalculateKPadded(K),
+                        GridwiseGemm::CalculateK0(K, KBatch),
                         KBatch};
     }
...
@@ -601,9 +274,9 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                         index_t StrideA,
                         index_t StrideB,
                         index_t StrideC,
-                        AElementwiseOperation a_element_op,
-                        BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
+                        AElementwiseOperation,
+                        BElementwiseOperation,
+                        CElementwiseOperation,
                         ck::index_t KBatch = 1) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
...
@@ -615,11 +288,10 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
                                           StrideA,
                                           StrideB,
                                           StrideC,
-                                          1,
-                                          1,
-                                          a_element_op,
-                                          b_element_op,
-                                          c_element_op,
+                                          GridwiseGemm::CalculateMPadded(M),
+                                          GridwiseGemm::CalculateNPadded(N),
+                                          GridwiseGemm::CalculateKPadded(K),
+                                          GridwiseGemm::CalculateK0(K, KBatch),
                                           KBatch);
     }
...
@@ -630,31 +302,7 @@ struct DeviceGemmXdlSplitKCShuffle : public DeviceGemmSplitK<ALayout,
     }

     // polymorphic
     std::string GetTypeString() const override
-    {
-        auto str = std::stringstream();
-
-        // clang-format off
-        str << "DeviceGemmXdlSplitKCShuffle"
-            << "<"
-            << BlockSize << ", "
-            << MPerBlock << ", "
-            << NPerBlock << ", "
-            << K0PerBlock << ", "
-            << K1 << ", "
-            << MPerXDL << ", "
-            << NPerXDL << ", "
-            << MXdlPerWave << ", "
-            << NXdlPerWave << ", "
-            << ABlockTransferSrcScalarPerVector << ", "
-            << ABlockTransferDstScalarPerVector_K1 << ", "
-            << BBlockTransferSrcScalarPerVector << ", "
-            << BBlockTransferDstScalarPerVector_K1
-            << ">";
-        // clang-format on
-
-        return str.str();
-    }
+    { return GridwiseGemm::GetTypeString(); }
 };

 } // namespace device
...
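After this refactor the invoker no longer derives the launch grid from stored tensor descriptors; it asks the default block-to-C-tile map for (gdx, gdy, gdz) given only M, N and k_batch. A hedged host-side sketch of what such a grid-size calculation amounts to (the real DefaultBlock2CTileMap lives in GridwiseGemm; the tile counting below is only an approximation):

    #include <tuple>

    // Approximate grid dimensions for a split-K GEMM: one workgroup per
    // (MPerBlock x NPerBlock) output tile, replicated k_batch times along z.
    std::tuple<int, int, int>
    calculate_grid_size(int M, int N, int k_batch, int MPerBlock, int NPerBlock)
    {
        const int m_tiles = (M + MPerBlock - 1) / MPerBlock;
        const int n_tiles = (N + NPerBlock - 1) / NPerBlock;
        return {m_tiles * n_tiles, 1, k_batch}; // gdx, gdy, gdz
    }

    // Example: M = 512, N = 768, MPerBlock = NPerBlock = 128, k_batch = 4
    // gives gdx = 4 * 6 = 24 and gdz = 4.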
include/ck/tensor_operation/gpu/device/impl/device_grouped_contraction_multiple_d_xdl_cshuffle.hpp

@@ -37,7 +37,8 @@ __global__ void
     const BElementwiseOperation b_element_op,
     const CDEElementwiseOperation cde_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
...
@@ -703,7 +704,8 @@ struct DeviceGroupedContractionMultipleD_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a"))
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940"))
         {
             return false;
         }
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp

@@ -130,7 +130,8 @@ __global__ void
     const Block2ETileMap block_2_ctile_map,
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp

@@ -78,7 +78,8 @@ __global__ void
     const Block2CTileMap block_2_ctile_map,
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_multiple_r_xdl_cshuffle.hpp

@@ -155,7 +155,8 @@ __global__ void
     const Block2ETileMap block_2_ctile_map,
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...
@@ -810,7 +811,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
                 return false;
             }
         }
-        else if(get_device_name() == "gfx90a")
+        else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940")
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
                            is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp

@@ -135,7 +135,8 @@ __global__ void
     const Block2ETileMap block_2_ctile_map,
     const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     // offset base pointer for each work-group
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...
@@ -684,7 +685,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle
                 return false;
             }
         }
-        else if(get_device_name() == "gfx90a")
+        else if(get_device_name() == "gfx90a" || get_device_name() == "gfx940")
        {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
                            is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp

@@ -38,7 +38,8 @@ __global__ void
     const BElementwiseOperation b_element_op,
     const CDEElementwiseOperation c_element_op)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     const index_t block_id = get_block_1d_id();
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/ck.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_gemm_splitk.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename GridwiseGemm,
          typename GemmDesc,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                       const index_t group_count)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx940__))
    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();

    __shared__ uint8_t p_shared[shared_size];

    const index_t block_id = get_block_1d_id();

    const auto gemm_desc_ptr =
        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));

    index_t left     = 0;
    index_t right    = group_count;
    index_t group_id = index_t((left + right) / 2);

    while((!(block_id >= gemm_desc_ptr[group_id].block_start_ &&
             block_id < gemm_desc_ptr[group_id].block_end_)) &&
          left <= right)
    {
        if(block_id < gemm_desc_ptr[group_id].block_start_)
        {
            right = group_id;
        }
        else
        {
            left = group_id;
        }
        group_id = index_t((left + right) / 2);
    }

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
        gemm_desc_ptr[group_id].karg_,
        static_cast<void*>(p_shared),
        gemm_desc_ptr[group_id].block_2_ctile_map_);
#else
    ignore = gemm_descs_const;
    ignore = group_count;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
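Each workgroup in the kernel above finds the GEMM it belongs to by binary-searching the sorted, contiguous [block_start_, block_end_) ranges stored in the per-group descriptors. A host-side sketch of the same lookup (std::vector standing in for the device-side descriptor array; it assumes block_id lies in exactly one range, as the launcher guarantees):

    #include <vector>

    struct BlockRange { int block_start_, block_end_; }; // mirrors GemmTransKernelArg bounds

    // Return the index of the group whose [block_start_, block_end_) contains block_id.
    int find_group(const std::vector<BlockRange>& groups, int block_id)
    {
        int left = 0, right = static_cast<int>(groups.size());
        int group_id = (left + right) / 2;
        while(!(block_id >= groups[group_id].block_start_ &&
                block_id < groups[group_id].block_end_) &&
              left <= right)
        {
            if(block_id < groups[group_id].block_start_) { right = group_id; }
            else                                         { left = group_id; }
            group_id = (left + right) / 2;
        }
        return group_id;
    }

    // Example: ranges {0,12}, {12,32}, {32,40}; block_id 20 resolves to group 1.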
template <typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename DsDataType,
          typename EDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
          GemmSpecialization GemmSpec,
          ck::index_t NumPrefetch,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t KPerBlock,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CDEBlockTransferScalarPerVector_NPerBlock,
          LoopScheduler LoopSched = make_default_loop_scheduler(),
          // Current implementation does not support multiple D fusions.
          enable_if_t<AK1 == BK1 && is_same_v<DsLayout, ck::Tuple<>> &&
                          is_same_v<DsDataType, ck::Tuple<>>,
                      bool> = false>
struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayout,
                                                                           BLayout,
                                                                           DsLayout,
                                                                           ELayout,
                                                                           ADataType,
                                                                           BDataType,
                                                                           DsDataType,
                                                                           EDataType,
                                                                           AElementwiseOperation,
                                                                           BElementwiseOperation,
                                                                           CDEElementwiseOperation>
{
    static constexpr index_t NumDTensor = DsDataType::Size();

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static_assert(KPerBlock % AK1 == 0);
    static constexpr index_t K0PerBlock = KPerBlock / AK1;

    using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2<
        BlockSize,
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        EDataType,
        ALayout,
        BLayout,
        ELayout,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
        GemmSpec,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXDL,
        NPerXDL,
        AK1,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsExtraN,
        CShuffleMXdlPerWavePerShuffle,
        CShuffleNXdlPerWavePerShuffle,
        CDEBlockTransferScalarPerVector_NPerBlock,
        CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>;

    using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;

    using Block2ETileMapKSplit =
        BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
    // Block2CTileMap configuration parameter.
    static constexpr index_t B2E_M01 = 8;
    using GroupedGemmBlock2ETileMap  = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
    using KernelArgument             = typename GridwiseGemm::Argument;

    struct GemmTransKernelArg
    {
        KernelArgument karg_;
        GroupedGemmBlock2ETileMap block_2_ctile_map_;
        index_t block_start_, block_end_;

        GemmTransKernelArg() = default;
        GemmTransKernelArg(KernelArgument&& karg,
                           GroupedGemmBlock2ETileMap&& b2c_map,
                           index_t block_start,
                           index_t block_end)
            : karg_{karg},
              block_2_ctile_map_{b2c_map},
              block_start_{block_start},
              block_end_{block_end}
        {
        }
    };

    static constexpr index_t DefaultKBatch = 1;
    // Argument
    struct Argument : public BaseArgument
    {
        Argument(std::vector<const void*>& p_As,
                 std::vector<const void*>& p_Bs,
                 std::vector<void*>& p_Es,
                 std::vector<GemmDesc>& gemm_descs)
            : Argument(p_As, p_Bs, p_Es, gemm_descs, DefaultKBatch)
        {
            // TODO: use occupancy api to calculate appropriate batch size.
        }

        Argument(std::vector<const void*>& p_As,
                 std::vector<const void*>& p_Bs,
                 std::vector<void*>& p_Es,
                 std::vector<GemmDesc>& gemm_descs,
                 index_t kbatch)
            : K_BATCH{kbatch}
        {
            grid_size_ = 0;

            group_count_ = ck::type_convert<ck::index_t>(gemm_descs.size());

            if(!(group_count_ == ck::type_convert<ck::index_t>(p_As.size()) &&
                 group_count_ == ck::type_convert<ck::index_t>(p_Bs.size()) &&
                 group_count_ == ck::type_convert<ck::index_t>(p_Es.size())))
            {
                throw std::runtime_error("wrong! group_count_ != p_As/b/c.size");
            }

            gemm_kernel_args_.reserve(group_count_);

            skipped_group_count_ = 0;

            for(std::size_t i = 0; i < gemm_descs.size(); ++i)
            {
                const index_t M = gemm_descs[i].M_;
                const index_t N = gemm_descs[i].N_;
                const index_t K = gemm_descs[i].K_;

                if(M == 0)
                {
                    skipped_group_count_++;
                    continue;
                }

                const index_t stride_a = gemm_descs[i].stride_A_;
                const index_t stride_b = gemm_descs[i].stride_B_;
                const index_t stride_c = gemm_descs[i].stride_C_;

                const index_t m_padded = GridwiseGemm::CalculateMPadded(M);
                const index_t n_padded = GridwiseGemm::CalculateNPadded(N);
                const index_t k_padded = GridwiseGemm::CalculateKPadded(K, K_BATCH);
                const index_t k0       = GridwiseGemm::CalculateK0(K, K_BATCH);

                const auto c_grid_desc_m_n =
                    GridwiseGemm::MakeCGridDescriptor_M_N(M, N, m_padded, n_padded, stride_c);

                const auto local_b2c_tile_map =
                    Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

                const index_t block_start = grid_size_;
                const index_t block_end   = grid_size_ + grid_size_grp;

                grid_size_ += grid_size_grp;

                // block-to-e-tile map
                auto grouped_block_2_ctile_map =
                    GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

                auto karg = KernelArgument{type_convert<const ADataType*>(p_As[i]),
                                           type_convert<const BDataType*>(p_Bs[i]),
                                           type_convert<EDataType*>(p_Es[i]),
                                           M,
                                           N,
                                           K,
                                           stride_a,
                                           stride_b,
                                           stride_c,
                                           m_padded,
                                           n_padded,
                                           k_padded,
                                           k0,
                                           K_BATCH};

                gemm_kernel_args_.emplace_back(
                    std::move(karg), std::move(grouped_block_2_ctile_map), block_start, block_end);
            }
        }

        /**
         * @brief      Recalculate group grid size for all gemms and update B2C maps.
         *
         * @param[in]  kbatch  The new splitK parameter value.
         */
        void UpdateKBatch(index_t kbatch)
        {
            K_BATCH    = kbatch;
            grid_size_ = 0;

            for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
            {
                auto& karg = gemm_kernel_args_[i].karg_;

                const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
                const index_t k0       = GridwiseGemm::CalculateK0(karg.K, K_BATCH);

                const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
                    karg.M, karg.N, karg.MPadded, karg.NPadded, karg.StrideC);

                const auto local_b2c_tile_map =
                    Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, K_BATCH};
                const index_t grid_size_grp = local_b2c_tile_map.CalculateGridSize(c_grid_desc_m_n);

                const index_t block_start = grid_size_;
                const index_t block_end   = grid_size_ + grid_size_grp;

                grid_size_ += grid_size_grp;

                // block-to-e-tile map
                auto grouped_block_2_ctile_map =
                    GroupedGemmBlock2ETileMap(local_b2c_tile_map, block_start);

                karg.KPadded = k_padded;
                karg.K0      = k0;
                karg.k_batch = K_BATCH;

                gemm_kernel_args_[i].block_2_ctile_map_ = grouped_block_2_ctile_map;
                gemm_kernel_args_[i].block_start_       = block_start;
                gemm_kernel_args_[i].block_end_         = block_end;
            }
        }

        //  private:
        index_t K_BATCH;
        index_t group_count_;
        index_t skipped_group_count_;
        std::vector<GemmTransKernelArg> gemm_kernel_args_;
        index_t grid_size_;
    };
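The Argument constructor above lays the groups out back-to-back in one launch: each group contributes grid_size_grp blocks, occupying [block_start, block_end) = [grid_size_, grid_size_ + grid_size_grp), and grid_size_ accumulates into the total grid. A small numeric sketch of that bookkeeping (the per-group block counts are made up):

    #include <cstdio>
    #include <vector>

    int main()
    {
        // Hypothetical per-group block counts (output tiles times k_batch) for three GEMMs.
        const std::vector<int> grid_size_grp = {12, 20, 8};

        int grid_size = 0;
        for(std::size_t i = 0; i < grid_size_grp.size(); ++i)
        {
            const int block_start = grid_size;
            const int block_end   = grid_size + grid_size_grp[i];
            grid_size += grid_size_grp[i];
            std::printf("group %zu: blocks [%d, %d)\n", i, block_start, block_end);
        }
        std::printf("total grid size: %d\n", grid_size); // 40 blocks in a single launch
        return 0;
    }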
    // Invoker
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            index_t K0                       = arg.gemm_kernel_args_[0].karg_.K0;
            bool all_have_kbatch_gt_one      = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
            bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

            for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
            {
                const auto& karg = arg.gemm_kernel_args_[i].karg_;
                if(stream_config.log_level_ > 0)
                {
                    karg.Print();
                }

                auto kbatch = karg.k_batch;

                if(!GridwiseGemm::CheckValidity(karg))
                {
                    std::ostringstream err;
                    err << "Group id: " << i << " has invalid GridwiseGemm settings!" << __FILE__
                        << ":" << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }

                K0 = karg.K0;
                bool not_all_have_main_k0_block_loop_same =
                    all_have_main_k0_block_loop xor GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
                bool not_all_have_kbatch_value_same = all_have_kbatch_gt_one xor (kbatch > 1);

                if(not_all_have_main_k0_block_loop_same)
                {
                    std::ostringstream err;
                    err << "Not all gemms have same value for main_k0_block_loop! in " << __FILE__
                        << ":" << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }

                if(not_all_have_kbatch_value_same)
                {
                    std::ostringstream err;
                    err << "Not all gemms have same kbatch value (=1 or >1)! "
                        << "group [" << i << "], kbatch: " << kbatch
                        << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
                        << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                    throw std::runtime_error(err.str());
                }
            }

            hip_check_error(hipMemcpy(arg.p_workspace_,
                                      arg.gemm_kernel_args_.data(),
                                      arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
                                      hipMemcpyHostToDevice));

            float ave_time = 0;

            const auto Run = [&](const auto& kernel) {
                if(all_have_kbatch_gt_one)
                {
                    for(const auto& trans_arg : arg.gemm_kernel_args_)
                    {
                        const auto& karg = trans_arg.karg_;
                        hip_check_error(
                            hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(EDataType)));
                    }
                }

                ave_time =
                    launch_and_time_kernel(stream_config,
                                           kernel,
                                           dim3(arg.grid_size_),
                                           dim3(BlockSize),
                                           0,
                                           cast_pointer_to_constant_address_space(arg.p_workspace_),
                                           arg.gemm_kernel_args_.size());
            };

            if(all_have_main_k0_block_loop)
            {
                if(all_have_kbatch_gt_one)
                {
                    const auto kernel =
                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                       GemmTransKernelArg,
                                                       true,
                                                       InMemoryDataOperationEnum::AtomicAdd>;
                    Run(kernel);
                }
                else
                {
                    const auto kernel =
                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                       GemmTransKernelArg,
                                                       true,
                                                       InMemoryDataOperationEnum::Set>;
                    Run(kernel);
                }
            }
            else
            {
                if(all_have_kbatch_gt_one)
                {
                    const auto kernel =
                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                       GemmTransKernelArg,
                                                       false,
                                                       InMemoryDataOperationEnum::AtomicAdd>;
                    Run(kernel);
                }
                else
                {
                    const auto kernel =
                        kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
                                                       GemmTransKernelArg,
                                                       false,
                                                       InMemoryDataOperationEnum::Set>;
                    Run(kernel);
                }
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    return true;
}

static bool IsSupportedArgument(const Argument& arg)
{
    if((ck::type_convert<ck::index_t>(arg.gemm_kernel_args_.size()) +
        arg.skipped_group_count_) != arg.group_count_)
    {
        return false;
    }

    bool supported = true;
    for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
    {
        const auto& a = arg.gemm_kernel_args_[i].karg_;

        bool group_arg_valid = GridwiseGemm::CheckValidity(a);

#if DEBUG_LOG
        if(not group_arg_valid)
        {
            std::cout << "[" << __func__ << "] group id: " << i << " is not supported!\n";
            a.Print();
        }
#endif // DEBUG_LOG

        supported &= group_arg_valid;
    }

    return supported;
}

// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
    return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(std::vector<const void*>& p_As,
                         std::vector<const void*>& p_Bs,
                         std::vector<std::array<const void*, NumDTensor>>&,
                         std::vector<void*>& p_Es,
                         std::vector<GemmDesc> gemm_descs,
                         AElementwiseOperation,
                         BElementwiseOperation,
                         CDEElementwiseOperation)
{
    return Argument{p_As, p_Bs, p_Es, gemm_descs};
}

static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(std::vector<const void*>& p_As,
                    std::vector<const void*>& p_Bs,
                    std::vector<std::array<const void*, NumDTensor>>&,
                    std::vector<void*>& p_Es,
                    std::vector<GemmDesc>& gemm_descs,
                    AElementwiseOperation,
                    BElementwiseOperation,
                    CDEElementwiseOperation) override
{
    return std::make_unique<Argument>(p_As, p_Bs, p_Es, gemm_descs);
}

// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
    return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
    auto str = std::stringstream();

    // clang-format off
    str << "DeviceGroupedGemm_XdlSplitK"
        << "<"
        << std::string(ALayout::name)[0]
        << "," << std::string(BLayout::name)[0]
        << "," << std::string(ELayout::name)[0]
        << "," << BlockSize
        << ", " << MPerBlock
        << ", " << NPerBlock
        << ", " << KPerBlock
        << ", " << AK1
        << ", " << BK1
        << ", " << MPerXDL
        << ", " << NPerXDL
        << ", " << MXdlPerWave
        << ", " << NXdlPerWave
        << ", " << ABlockTransferSrcScalarPerVector
        << ", " << BBlockTransferSrcScalarPerVector
        << ", " << CShuffleMXdlPerWavePerShuffle
        << ", " << CShuffleNXdlPerWavePerShuffle
        << ", " << getGemmSpecializationString(GemmSpec)
        << ">";
    // clang-format on

    return str.str();
}

size_t GetWorkSpaceSize(const BaseArgument* p_arg) const override
{
    return dynamic_cast<const Argument*>(p_arg)->gemm_kernel_args_.size() *
           sizeof(GemmTransKernelArg);
}
}
static
void
SetKBatchSize
(
Argument
&
arg
,
index_t
kbatch
)
{
arg
.
UpdateKBatch
(
kbatch
);
}
// polymorphic
void
SetKBatchSize
(
BaseArgument
*
p_arg
,
index_t
kbatch
)
const
override
{
return
SetKBatchSize
(
*
dynamic_cast
<
Argument
*>
(
p_arg
),
kbatch
);
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
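To show how the pieces above fit together, here is a minimal host-side sketch of driving this operation. It is illustrative only: DeviceOp stands for a concrete DeviceGroupedGemm_XdlSplitK instantiation, PassThrough elementwise ops are assumed, the pointer vectors p_As/p_Bs/p_Ds/p_Es and gemm_descs are prepared by the caller, and SetWorkSpacePointer is the usual helper from the device-op base class.

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    DeviceOp gemm;
    auto invoker  = gemm.MakeInvoker();
    auto argument = gemm.MakeArgument(
        p_As, p_Bs, p_Ds, p_Es, gemm_descs, PassThrough{}, PassThrough{}, PassThrough{});

    // Stage the per-group kernel arguments in a device-side workspace;
    // Invoker::Run copies gemm_kernel_args_ into this buffer before launching.
    void* p_workspace = nullptr;
    hip_check_error(hipMalloc(&p_workspace, gemm.GetWorkSpaceSize(&argument)));
    gemm.SetWorkSpacePointer(&argument, p_workspace);

    // Split the K dimension into 4 partial sums; this calls UpdateKBatch() and
    // recomputes K0/KPadded and the per-group block ranges shown earlier.
    DeviceOp::SetKBatchSize(argument, 4);

    if(gemm.IsSupportedArgument(argument))
    {
        const float ave_time = invoker.Run(argument, StreamConfig{nullptr, true});
        // ave_time is the measured kernel time when timing is enabled.
    }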
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
@@ -56,6 +56,12 @@ struct PassThrough
        y = type_convert<bhalf_t>(x);
    }

    template <>
    __host__ __device__ void operator()<bhalf_t, half_t>(bhalf_t& y, const half_t& x) const
    {
        y = type_convert<bhalf_t>(x);
    }

    template <>
    __host__ __device__ void operator()<int8_t, int8_t>(int8_t& y, const int8_t& x) const
    {
...
@@ -86,6 +92,23 @@ struct UnaryConvert
    }
};

struct ConvertBF16RTN
{
    // convert to bf16 using round to nearest (rtn)
    template <typename Y, typename X>
    __host__ __device__ void operator()(Y& y, const X& x) const
    {
        // check Y datatype
        static_assert(is_same<Y, bhalf_t>::value, "Data type is not supported by this operation!");

        // check X datatype
        static_assert(is_same<X, float>::value || is_same<X, half_t>::value,
                      "Data type is not supported by this operation!");

        y = bf16_convert_rtn<Y>(x);
    }
};

struct Scale
{
    __host__ __device__ Scale(float scale) : scale_(scale) {}
...
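A small usage sketch for the new functor (values are illustrative; the include path is the file shown above). ConvertBF16RTN is a host/device functor, so it can be exercised on the host as well as inside a kernel:

    #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

    void convert_bf16_rtn_examples()
    {
        ck::tensor_operation::element_wise::ConvertBF16RTN convert{};

        ck::bhalf_t y;
        convert(y, 1.5f);                               // float  -> bf16, round-to-nearest
        convert(y, ck::type_convert<ck::half_t>(1.5f)); // half_t -> bf16, round-to-nearest
    }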
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
@@ -587,4 +587,52 @@ struct OffsettedBlockToCTileMap
    index_t block_start_;
};

/**
 * @brief Simple tile mapping which creates a 3D grid of thread blocks.
 *
 * @paragraph Description
 *            This Block-to-C-tile-map creates a 3D grid (n_blocks, m_blocks, z_blocks) of thread
 *            blocks. The first two dimensions are regular 2D tiles created by dividing the output
 *            GEMM dimensions by the corresponding tile sizes. The third dimension (Z) is a k-split
 *            dimension, which denotes the number of blocks used to divide work along the GEMM K
 *            dimension.
 *
 * @tparam MPerBlock Output block tile size in M dimension.
 * @tparam NPerBlock Output block tile size in N dimension.
 */
template <index_t MPerBlock, index_t NPerBlock>
struct BlockToCTileMap_3DGrid_KSplit
{
    __host__ __device__ BlockToCTileMap_3DGrid_KSplit() = default;

    __host__ __device__ constexpr auto
    CalculateGridSize(index_t M, index_t N, index_t k_split) const
    {
        // Create 3D grid
        const auto M0 = math::integer_divide_ceil(M, MPerBlock);
        const auto N0 = math::integer_divide_ceil(N, NPerBlock);
        return std::make_tuple(N0, M0, k_split);
    }

    template <typename TopIdx>
    __device__ constexpr auto CalculateBottomIndex(const TopIdx&) const
    {
        return make_tuple(blockIdx.z, blockIdx.y, blockIdx.x);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                             const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    template <typename CGridDesc_M_N>
    __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }
};

} // namespace ck
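A brief host-side sketch of the new mapping (problem shape and tile sizes are illustrative):

    #include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"

    // 1024 x 2048 GEMM output, 128 x 128 tiles, K split into 4 slices.
    using TileMap = ck::BlockToCTileMap_3DGrid_KSplit<128, 128>;

    const auto grid = TileMap{}.CalculateGridSize(1024, 2048, 4);
    // grid == (N0, M0, k_split) == (16, 8, 4), so the kernel is launched with dim3(16, 8, 4);
    // each block then recovers (k-slice, m-tile, n-tile) from
    // (blockIdx.z, blockIdx.y, blockIdx.x) via CalculateBottomIndex().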
include/ck/tensor_operation/gpu/grid/gridwise_gemm_bias_add_reduce_xdl_cshuffle_v1.hpp
@@ -66,7 +66,8 @@ __global__ void
            const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
            const Block2CTileMap block_2_ctile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
+    defined(__gfx940__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
@@ -96,7 +96,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
    // we convert fp16->fp32->bf16 and execute bf16 mfma instruction
    // when mfma if fixed, remove this section and update
    // ABDataTypeAdjusted -> ABDataType throughout this file
-#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
+#if CK_WORKAROUND_DENORM_FIX
    using ABDataTypeAdjusted =
        conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
#else
...
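As a side note, a tiny standalone illustration (using standard type traits rather than CK's internal aliases) of the type selection this workaround performs: fp16 input data is routed through the bf16 MFMA path, while every other type is left unchanged.

    #include <type_traits>
    #include "ck/utility/data_type.hpp"

    using AdjustedFromHalf =
        std::conditional_t<std::is_same_v<ck::half_t, ck::half_t>, ck::bhalf_t, ck::half_t>;
    using AdjustedFromFloat =
        std::conditional_t<std::is_same_v<float, ck::half_t>, ck::bhalf_t, float>;

    static_assert(std::is_same_v<AdjustedFromHalf, ck::bhalf_t>, "half_t is adjusted to bhalf_t");
    static_assert(std::is_same_v<AdjustedFromFloat, float>, "other types pass through unchanged");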