gaoqiong / composable_kernel · Commits · f4208484

Commit f4208484, authored Jul 07, 2023 by Adam Osewski (parent 5afba5aa)

Simplify kernel arguments.
Calculate descriptors & B2C maps on the device.
Showing 2 changed files with 178 additions and 61 deletions (+178 -61)

  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp  (+131 -52)
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp  (+47 -9)
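Both the old and the new kernel rely on the same persistent tile-loop dispatch: each GEMM group owns a half-open range [block_start, block_end) of output tiles, and every workgroup walks flat tile ids with a grid-sized stride, first locating the owning group and then re-basing the tile id into that group. The sketch below is a minimal host-side illustration of that lookup, not code from the repository; GroupRange and find_group are hypothetical names introduced only for this example.

#include <cstdio>
#include <vector>

// Hypothetical mirror of the per-group tile range carried in the kernel arguments
// (block_start / block_end in the diff). Not part of composable_kernel.
struct GroupRange
{
    int block_start; // first flat tile id owned by this group
    int block_end;   // one past the last flat tile id owned by this group
};

// Mirrors the inner while-loop of kernel_grouped_gemm_xdl_splitk: advance group_id
// until tile_id falls into [block_start, block_end), accumulating the offset that
// re-bases the flat tile id into a group-local tile id.
void find_group(const std::vector<GroupRange>& groups, int tile_id, int& group_id, int& local_tile)
{
    int offset = 0;
    group_id   = 0;
    while(!(tile_id >= groups[group_id].block_start && tile_id < groups[group_id].block_end))
    {
        offset += groups[group_id].block_end - groups[group_id].block_start;
        ++group_id;
    }
    local_tile = tile_id - offset;
}

int main()
{
    // Three GEMM groups with 4, 2 and 6 output tiles respectively.
    std::vector<GroupRange> groups{{0, 4}, {4, 6}, {6, 12}};
    const int tile_count = 12;
    const int grid_size  = 5; // pretend we launched 5 workgroups

    for(int first = 0; first < grid_size; ++first)                            // one pass per "workgroup"
        for(int tile_id = first; tile_id < tile_count; tile_id += grid_size)  // grid-stride tile loop
        {
            int group_id = 0, local_tile = 0;
            find_group(groups, tile_id, group_id, local_tile);
            std::printf("tile %2d -> group %d, local tile %d\n", tile_id, group_id, local_tile);
        }
    return 0;
}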
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp
@@ -27,10 +27,10 @@ namespace device {
 //
 // @brief      Entry point kernel for device-wide Grouped GEMM operation.
 //
-// @param[in]  gemm_descs_const  The pointer to the array of GEMM descriptor structures in
-//                               constant memory.
+// @param[in]  gemm_desc_const   The pointer to the array of GEMM descriptor structures.
 // @param[in]  tile_count        The overall number of output tiles we divided all groups
 //                               into.
+// @param[in]  k_batch           The number of batches we split the K dimension into.
 //
 // @tparam     GridwiseGemm      The specific GridwiseGEMM algorithm implementation.
 // @tparam     GemmDesc          The structure holding all necessary descriptors and other
@@ -43,24 +43,27 @@ namespace device {
 //
 template <typename GridwiseGemm,
           typename GemmDesc,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
           bool HasMainKBlockLoop,
           InMemoryDataOperationEnum CGlobalMemoryDataOperation>
 __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
 #endif
-    kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
-                                   const index_t tile_count)
+    kernel_grouped_gemm_xdl_splitk(const void* gemm_desc_const,
+                                   const index_t tile_count,
+                                   const index_t k_batch)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
     constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
     __shared__ uint8_t p_shared[shared_size];

     index_t tile_id         = get_block_1d_id();
     const index_t grid_size = get_grid_size();

     const auto gemm_desc_ptr =
-        reinterpret_cast<const GemmDesc*>(cast_pointer_to_generic_address_space(gemm_descs_const));
+        reinterpret_cast<const GemmDesc*>(gemm_desc_const);

     index_t group_id = 0;
     index_t offset   = 0;
@@ -68,26 +71,67 @@ __global__ void
     while(tile_id < tile_count)
     {
         // Find corresponding GEMM group for out tile
-        while(!(tile_id >= gemm_desc_ptr[group_id].block_start_ &&
-                tile_id < gemm_desc_ptr[group_id].block_end_))
+        while(!(tile_id >= gemm_desc_ptr[group_id].block_start &&
+                tile_id < gemm_desc_ptr[group_id].block_end))
         {
-            offset += gemm_desc_ptr[group_id].block_end_ - gemm_desc_ptr[group_id].block_start_;
+            offset += gemm_desc_ptr[group_id].block_end - gemm_desc_ptr[group_id].block_start;
             group_id++;
         }

-        LocalBlockToCTileMap<typename GemmDesc::B2CType> local_b2c{
-            gemm_desc_ptr[group_id].block_2_ctile_map_, tile_id - offset};
+        const auto p_a_grid = reinterpret_cast<const FloatA*>(gemm_desc_ptr[group_id].p_a_grid);
+        const auto p_b_grid = reinterpret_cast<const FloatB*>(gemm_desc_ptr[group_id].p_b_grid);
+        const auto p_c_grid = reinterpret_cast<FloatC*>(gemm_desc_ptr[group_id].p_c_grid);
+
+        const auto M       = gemm_desc_ptr[group_id].M;
+        const auto N       = gemm_desc_ptr[group_id].N;
+        const auto K       = gemm_desc_ptr[group_id].K;
+        const auto StrideA = gemm_desc_ptr[group_id].StrideA;
+        const auto StrideB = gemm_desc_ptr[group_id].StrideB;
+        const auto StrideC = gemm_desc_ptr[group_id].StrideC;
+
+        const auto MPadded = GridwiseGemm::CalculateMPadded(M);
+        const auto NPadded = GridwiseGemm::CalculateNPadded(N);
+        const auto KPadded = GridwiseGemm::CalculateKPadded(K, k_batch);
+        const auto K0      = GridwiseGemm::CalculateK0(K, k_batch);
+
+        static constexpr index_t MPerBlock = GridwiseGemm::GetMPerBlock();
+        static constexpr index_t NPerBlock = GridwiseGemm::GetNPerBlock();
+        static constexpr index_t B2E_M01   = 8;
+
+        using CGridDesc_M_N = typename GridwiseGemm::CGridDesc_M_N;
+        using Block2ETileMapKSplit =
+            BlockToCTileMap_KSplit_M00_N0_M01Adapt<MPerBlock, NPerBlock, CGridDesc_M_N>;
+
+        const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(M, N, StrideC);
+        const auto b2c_tile_map    = Block2ETileMapKSplit{c_grid_desc_m_n, B2E_M01, k_batch};
+        LocalBlockToCTileMap<Block2ETileMapKSplit> local_b2c{b2c_tile_map, tile_id - offset};

         GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation>(
-            gemm_desc_ptr[group_id].karg_, static_cast<void*>(p_shared), local_b2c);
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            M,
+            N,
+            K,
+            StrideA,
+            StrideB,
+            StrideC,
+            MPadded,
+            NPadded,
+            KPadded,
+            K0,
+            k_batch,
+            static_cast<void*>(p_shared),
+            local_b2c);

         tile_id += grid_size;
     }
 #else
-    ignore = gemm_descs_const;
+    ignore = gemm_desc_const;
     ignore = group_count;
     ignore = tile_count;
+    ignore = k_batch;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
@@ -213,21 +257,12 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
     struct GemmTransKernelArg
     {
-        using B2CType = Block2ETileMapKSplit;
-
-        KernelArgument karg_;
-        Block2ETileMapKSplit block_2_ctile_map_;
-        index_t block_start_, block_end_;
+        KernelArgument karg;
+        index_t block_start, block_end;

         GemmTransKernelArg() = default;
-        GemmTransKernelArg(KernelArgument&& karg,
-                           Block2ETileMapKSplit&& b2c_map,
-                           index_t block_start,
-                           index_t block_end)
-            : karg_{karg},
-              block_2_ctile_map_{b2c_map},
-              block_start_{block_start},
-              block_end_{block_end}
+        GemmTransKernelArg(KernelArgument&& karg_, index_t block_start_, index_t block_end_)
+            : karg{karg_}, block_start{block_start_}, block_end{block_end_}
         {
         }
     };
@@ -265,7 +300,6 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             }

             gemm_kernel_args_.reserve(group_count_);
             skipped_group_count_ = 0;

             for(std::size_t i = 0; i < gemm_descs.size(); ++i)
@@ -314,8 +348,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                                           k0,
                                           K_BATCH};

             gemm_kernel_args_.emplace_back(
-                std::move(karg), std::move(local_b2c_tile_map), block_start, block_end);
+                std::move(karg), block_start, block_end);
         }
     }
@@ -332,7 +365,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         for(std::size_t i = 0; i < gemm_kernel_args_.size(); ++i)
         {
-            auto& karg = gemm_kernel_args_[i].karg_;
+            auto& karg = gemm_kernel_args_[i].karg;
             const index_t k_padded = GridwiseGemm::CalculateKPadded(karg.K, K_BATCH);
             const index_t k0       = GridwiseGemm::CalculateK0(karg.K, K_BATCH);
@@ -348,12 +381,11 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             grid_size_ += grid_size_grp;

             karg.KPadded = k_padded;
             karg.K0      = k0;
             karg.k_batch = K_BATCH;
-            gemm_kernel_args_[i].block_2_ctile_map_ = local_b2c_tile_map;
-            gemm_kernel_args_[i].block_start_       = block_start;
-            gemm_kernel_args_[i].block_end_         = block_end;
+            gemm_kernel_args_[i].block_start = block_start;
+            gemm_kernel_args_[i].block_end   = block_end;
         }
     }
@@ -377,15 +409,35 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         // assume we want to have at most 2 waves per SIMD
         static constexpr int CU_BLOCKS = math::integer_divide_floor(8, BLOCK_WAVES);

+        struct SimpleGemmArgument
+        {
+            const void* p_a_grid;
+            const void* p_b_grid;
+            void* p_c_grid;
+
+            index_t M;
+            index_t N;
+            index_t K;
+
+            index_t StrideA;
+            index_t StrideB;
+            index_t StrideC;
+
+            index_t block_start;
+            index_t block_end;
+        };
+
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            index_t K0                  = arg.gemm_kernel_args_[0].karg_.K0;
-            bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg_.k_batch > 1;
+            std::vector<SimpleGemmArgument> simple_gemm_kernel_args;
+            simple_gemm_kernel_args.reserve(arg.gemm_kernel_args_.size());
+
+            index_t K0                  = arg.gemm_kernel_args_[0].karg.K0;
+            bool all_have_kbatch_gt_one = arg.gemm_kernel_args_[0].karg.k_batch > 1;
             bool all_have_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);

             for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
             {
-                const auto& karg = arg.gemm_kernel_args_[i].karg_;
+                const auto& karg = arg.gemm_kernel_args_[i].karg;
                 // if(stream_config.log_level_ > 0)
                 // {
                 //     karg.Print();
@@ -419,16 +471,30 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                     std::ostringstream err;
                     err << "Not all gemms have same kbatch value (=1 or >1)! "
                         << "group [" << i << "], kbatch: " << kbatch
-                        << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg_.k_batch
+                        << ", group [0], kbatch: " << arg.gemm_kernel_args_[0].karg.k_batch
                         << " in " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
                     throw std::runtime_error(err.str());
                 }
+
+                simple_gemm_kernel_args.push_back({karg.p_a_grid,
+                                                   karg.p_b_grid,
+                                                   karg.p_c_grid,
+                                                   karg.M,
+                                                   karg.N,
+                                                   karg.K,
+                                                   karg.StrideA,
+                                                   karg.StrideB,
+                                                   karg.StrideC,
+                                                   arg.gemm_kernel_args_[i].block_start,
+                                                   arg.gemm_kernel_args_[i].block_end});
             }

+            using GemmArgumentType = SimpleGemmArgument;
+
             hip_check_error(
                 hipMemcpyWithStream(arg.p_workspace_,
-                                    arg.gemm_kernel_args_.data(),
-                                    arg.gemm_kernel_args_.size() * sizeof(GemmTransKernelArg),
+                                    simple_gemm_kernel_args.data(),
+                                    simple_gemm_kernel_args.size() * sizeof(GemmArgumentType),
                                     hipMemcpyHostToDevice,
                                     stream_config.stream_id_));
@@ -439,7 +505,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             {
                 for(const auto& trans_arg : arg.gemm_kernel_args_)
                 {
-                    const auto& karg = trans_arg.karg_;
+                    const auto& karg = trans_arg.karg;
                     hip_check_error(
                         hipMemset(karg.p_c_grid, 0, karg.M * karg.N * sizeof(EDataType)));
                 }
@@ -466,8 +532,9 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
                                                       BLOCK_SUBSCRIPTION_FACTOR),
                                   dim3(BlockSize),
                                   0,
-                                  cast_pointer_to_constant_address_space(arg.p_workspace_),
-                                  arg.grid_size_);
+                                  arg.p_workspace_,
+                                  arg.grid_size_,
+                                  arg.gemm_kernel_args_[0].karg.k_batch);
             };

             if(all_have_main_k0_block_loop)
@@ -476,7 +543,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             {
                 const auto kernel =
                     kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                   GemmTransKernelArg,
+                                                   GemmArgumentType,
+                                                   ADataType,
+                                                   BDataType,
+                                                   EDataType,
                                                    true,
                                                    InMemoryDataOperationEnum::AtomicAdd>;
@@ -486,7 +556,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             {
                 const auto kernel =
                     kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                   GemmTransKernelArg,
+                                                   GemmArgumentType,
+                                                   ADataType,
+                                                   BDataType,
+                                                   EDataType,
                                                    true,
                                                    InMemoryDataOperationEnum::Set>;
@@ -499,7 +572,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             {
                 const auto kernel =
                     kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                   GemmTransKernelArg,
+                                                   GemmArgumentType,
+                                                   ADataType,
+                                                   BDataType,
+                                                   EDataType,
                                                    false,
                                                    InMemoryDataOperationEnum::AtomicAdd>;
@@ -509,7 +585,10 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
             {
                 const auto kernel =
                     kernel_grouped_gemm_xdl_splitk<GridwiseGemm,
-                                                   GemmTransKernelArg,
+                                                   GemmArgumentType,
+                                                   ADataType,
+                                                   BDataType,
+                                                   EDataType,
                                                    false,
                                                    InMemoryDataOperationEnum::Set>;
@@ -550,7 +629,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
         bool supported = true;
         for(std::size_t i = 0; i < arg.gemm_kernel_args_.size(); ++i)
         {
-            const auto& a = arg.gemm_kernel_args_[i].karg_;
+            const auto& a = arg.gemm_kernel_args_[i].karg;
             bool group_arg_valid = GridwiseGemm::CheckValidity(a);
             if(not group_arg_valid)
             {
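A reminder of what k_batch controls, since it now reaches the kernel as a plain scalar: with split-K, each of the k_batch batches reduces a disjoint slice of the K dimension into the same C tile, so when k_batch > 1 the output must be zero-initialized (the hipMemset above) and the partial results combined with atomic adds rather than plain stores (the AtomicAdd vs. Set kernel instantiations). The sketch below is a minimal CPU analogue of that accumulation, not CK code; splitting K into equal chunks is an assumption made only for illustration.

#include <cstdio>
#include <vector>

// CPU analogue of split-K accumulation: k_batch partial reductions over disjoint
// K-slices are added into a zero-initialized output, which is the same reason the
// device path memsets C and uses atomic adds when k_batch > 1.
int main()
{
    const int K           = 8;
    const int k_batch     = 4;      // assume K is divisible by k_batch for simplicity
    const int k_per_batch = K / k_batch;

    std::vector<float> a(K), b(K);
    for(int k = 0; k < K; ++k)
    {
        a[k] = 1.0f + k;
        b[k] = 0.5f;
    }

    float c = 0.0f;                  // analogue of hipMemset(p_c_grid, 0, ...)
    for(int batch = 0; batch < k_batch; ++batch)
    {
        float partial = 0.0f;        // each batch reduces its own K-slice
        for(int k = batch * k_per_batch; k < (batch + 1) * k_per_batch; ++k)
            partial += a[k] * b[k];
        c += partial;                // analogue of the AtomicAdd store into C
    }

    float reference = 0.0f;
    for(int k = 0; k < K; ++k)
        reference += a[k] * b[k];

    std::printf("split-K result: %f, single-pass result: %f\n", c, reference);
    return 0;
}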
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r4r2.hpp
@@ -573,18 +573,28 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
     template <bool HasMainKBlockLoop,
               InMemoryDataOperationEnum CGlobalMemoryDataOperation,
               typename Block2CTileMap>
-    __device__ static void Run(const Argument& karg,
+    __device__ static void Run(const FloatAB* p_a_grid,
+                               const FloatAB* p_b_grid,
+                               FloatC* p_c_grid,
+                               index_t M,
+                               index_t N,
+                               index_t K,
+                               index_t StrideA,
+                               index_t StrideB,
+                               index_t StrideC,
+                               index_t MPadded,
+                               index_t NPadded,
+                               index_t KPadded,
+                               index_t K0,
+                               index_t k_batch,
                                void* __restrict__ p_shared_block,
                                const Block2CTileMap& block_2_ctile_map)
     {
-        const FloatAB* p_a_grid = karg.p_a_grid;
-        const FloatAB* p_b_grid = karg.p_b_grid;
-        FloatC* p_c_grid        = karg.p_c_grid;
-        const auto a_b_k0_m_k1_grid_desc = MakeAGridDescriptor_KBatch_K0_M_K1(
-            karg.M, karg.MPadded, karg.K, karg.StrideA, karg.k_batch, karg.K0, karg.KPadded);
-        const auto b_b_k0_n_k1_grid_desc = MakeBGridDescriptor_KBatch_K0_N_K1(
-            karg.K, karg.NPadded, karg.N, karg.StrideB, karg.k_batch, karg.K0, karg.KPadded);
-        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC);
+        const auto a_b_k0_m_k1_grid_desc =
+            MakeAGridDescriptor_KBatch_K0_M_K1(M, MPadded, K, StrideA, k_batch, K0, KPadded);
+        const auto b_b_k0_n_k1_grid_desc =
+            MakeBGridDescriptor_KBatch_K0_N_K1(K, NPadded, N, StrideB, k_batch, K0, KPadded);
+        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(M, N, StrideC);

         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
             MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(c_grid_desc_m_n);
@@ -1056,6 +1066,34 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2
         }
     }

+    template <bool HasMainKBlockLoop,
+              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
+              typename Block2CTileMap>
+    __device__ static void Run(const Argument& karg,
+                               void* __restrict__ p_shared_block,
+                               const Block2CTileMap& block_2_ctile_map)
+    {
+        Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, Block2CTileMap>(karg.p_a_grid,
+                                                                           karg.p_b_grid,
+                                                                           karg.p_c_grid,
+                                                                           karg.M,
+                                                                           karg.N,
+                                                                           karg.K,
+                                                                           karg.StrideA,
+                                                                           karg.StrideB,
+                                                                           karg.StrideC,
+                                                                           karg.MPadded,
+                                                                           karg.NPadded,
+                                                                           karg.KPadded,
+                                                                           karg.K0,
+                                                                           karg.k_batch,
+                                                                           p_shared_block,
+                                                                           block_2_ctile_map);
+    }
+
+    static constexpr auto GetMPerBlock() { return MPerBlock; }
+    static constexpr auto GetNPerBlock() { return NPerBlock; }
+
     static std::string GetTypeString()
     {
         auto str = std::stringstream();
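The second hunk keeps the old Argument-based entry point as a thin overload that unpacks the struct and forwards to the new raw-parameter Run, so existing call sites keep compiling while the grouped tile-loop kernel can pass per-group scalars it derived on the device. A generic, hypothetical illustration of that forwarding-overload pattern (these are stand-in types, not the CK ones):

#include <cstdio>

// Hypothetical stand-ins for the real kernel argument bundle and implementation.
struct Argument
{
    int M, N, K;
};

// New-style entry point: takes the raw scalars directly, so a caller that only
// has per-group values (as the tile-loop kernel now does) can use it.
void run(int M, int N, int K)
{
    std::printf("run with M=%d N=%d K=%d\n", M, N, K);
}

// Old-style convenience overload kept for existing call sites: unpack and forward.
void run(const Argument& karg) { run(karg.M, karg.N, karg.K); }

int main()
{
    run(Argument{128, 256, 64}); // old call form still works
    run(64, 64, 32);             // new call form with raw parameters
    return 0;
}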