Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6f774178
Commit
6f774178
authored
Jan 18, 2023
by
root
Browse files
pass device arrays as separate args
parent
715e8dd2
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
124 additions
and
26 deletions
+124
-26
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
...sor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
+122
-24
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+2
-2
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
View file @
6f774178
...
@@ -23,6 +23,9 @@ namespace tensor_operation {
...
@@ -23,6 +23,9 @@ namespace tensor_operation {
namespace
device
{
namespace
device
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
ADataType
,
typename
BDataType
,
typename
EDataType
,
typename
GemmDesc
,
typename
GemmDesc
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
...
@@ -32,23 +35,46 @@ __global__ void
...
@@ -32,23 +35,46 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_grouped_gemm_xdl
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_const
,
kernel_grouped_gemm_xdl
(
const
index_t
group_count
,
#if 0
const
AElementwiseOperation
a_element_op
,
const void CK_CONSTANT_ADDRESS_SPACE* a_ptr,
const
BElementwiseOperation
b_element_op
,
const void CK_CONSTANT_ADDRESS_SPACE* b_ptr,
const
CDEElementwiseOperation
c_element_op
)
const void CK_CONSTANT_ADDRESS_SPACE* ds_ptr,
const void CK_CONSTANT_ADDRESS_SPACE* e_ptr,
#endif
const
ADataType
**
a_ptr_
,
const
BDataType
**
b_ptr_
,
const
typename
GridwiseGemm
::
DsGridPointer
*
ds_ptr_
,
EDataType
**
e_ptr_
,
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_ptr_
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
c_element_op
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
block_id
=
get_block_1d_id
();
const
index_t
block_id
=
get_block_1d_id
();
#if 0
const auto a_ptr_ =
static_cast<ADataType* const*>(cast_pointer_to_generic_address_space(a_ptr));
const auto b_ptr_ =
static_cast<BDataType* const*>(cast_pointer_to_generic_address_space(b_ptr));
const auto ds_ptr_ = static_cast<typename GridwiseGemm::DsGridPointer const*>(
cast_pointer_to_generic_address_space(ds_ptr));
const auto e_ptr_ =
static_cast<EDataType* const*>(cast_pointer_to_generic_address_space(e_ptr));
#endif
const
auto
gemm_desc_ptr
=
const
auto
gemm_desc_ptr
=
reinterpret
_cast
<
const
GemmDesc
*>
(
cast_pointer_to_generic_address_space
(
gemm_descs_
const
));
static
_cast
<
const
GemmDesc
*>
(
cast_pointer_to_generic_address_space
(
gemm_descs_
ptr_
));
index_t
left
=
0
;
index_t
left
=
0
;
index_t
right
=
group_count
;
index_t
right
=
group_count
;
index_t
group_id
=
index_t
((
left
+
right
)
/
2
);
index_t
group_id
=
index_t
((
left
+
right
)
/
2
);
while
((
!
(
block_id
>=
gemm_desc_ptr
[
group_id
].
BlockStart_
&&
while
((
!
(
block_id
>=
gemm_desc_ptr
[
group_id
].
BlockStart_
&&
block_id
<
gemm_desc_ptr
[
group_id
].
BlockEnd_
))
&&
block_id
<
gemm_desc_ptr
[
group_id
].
BlockEnd_
))
&&
left
<=
right
)
left
<=
right
)
...
@@ -64,11 +90,12 @@ __global__ void
...
@@ -64,11 +90,12 @@ __global__ void
group_id
=
index_t
((
left
+
right
)
/
2
);
group_id
=
index_t
((
left
+
right
)
/
2
);
}
}
#if 1
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
gemm_desc
_ptr
[
group_id
]
.
a_ptr_
,
a
_ptr
_
[
group_id
],
gemm_desc
_ptr
[
group_id
]
.
b_ptr_
,
b
_ptr
_
[
group_id
],
gemm_desc
_ptr
[
group_id
]
.
ds_ptr_
,
ds
_ptr
_
[
group_id
],
gemm_desc
_ptr
[
group_id
]
.
e_ptr_
,
e
_ptr
_
[
group_id
],
p_shared
,
p_shared
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
...
@@ -78,6 +105,8 @@ __global__ void
...
@@ -78,6 +105,8 @@ __global__ void
gemm_desc_ptr
[
group_id
].
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
gemm_desc_ptr
[
group_id
].
ds_grid_desc_mblock_mperblock_nblock_nperblock_
,
gemm_desc_ptr
[
group_id
].
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
gemm_desc_ptr
[
group_id
].
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
gemm_desc_ptr
[
group_id
].
block_2_etile_map_
);
gemm_desc_ptr
[
group_id
].
block_2_etile_map_
);
#endif
#else
#else
ignore
=
gemm_descs_const
;
ignore
=
gemm_descs_const
;
ignore
=
group_count
;
ignore
=
group_count
;
...
@@ -323,12 +352,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
...
@@ -323,12 +352,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
struct
GemmBiasTransKernelArg
struct
GemmBiasTransKernelArg
{
{
// pointers
const
ADataType
*
a_ptr_
;
const
BDataType
*
b_ptr_
;
typename
GridwiseGemm
::
DsGridPointer
ds_ptr_
;
EDataType
*
e_ptr_
;
// tensor descriptors for problem definition
// tensor descriptors for problem definition
AGridDesc_M_K
a_grid_desc_m_k_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
...
@@ -456,12 +479,13 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
...
@@ -456,12 +479,13 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n
);
e_grid_desc_m_n
);
a_ptr_
.
push_back
(
static_cast
<
const
ADataType
*>
(
p_As
[
i
]));
b_ptr_
.
push_back
(
static_cast
<
const
BDataType
*>
(
p_Bs
[
i
]));
ds_ptr_
.
push_back
(
p_ds_grid
);
e_ptr_
.
push_back
(
static_cast
<
EDataType
*>
(
p_Es
[
i
]));
gemm_desc_kernel_arg_
.
push_back
(
gemm_desc_kernel_arg_
.
push_back
(
GemmBiasTransKernelArg
{
static_cast
<
const
ADataType
*>
(
p_As
[
i
]),
GemmBiasTransKernelArg
{
a_grid_desc_m_k
,
static_cast
<
const
BDataType
*>
(
p_Bs
[
i
]),
p_ds_grid
,
static_cast
<
EDataType
*>
(
p_Es
[
i
]),
a_grid_desc_m_k
,
b_grid_desc_n_k
,
b_grid_desc_n_k
,
ds_grid_desc_m_n
,
ds_grid_desc_m_n
,
e_grid_desc_m_n
,
e_grid_desc_m_n
,
...
@@ -484,6 +508,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
...
@@ -484,6 +508,11 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
BElementwiseOperation
b_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
c_element_op_
;
CDEElementwiseOperation
c_element_op_
;
std
::
vector
<
const
ADataType
*>
a_ptr_
;
std
::
vector
<
const
BDataType
*>
b_ptr_
;
std
::
vector
<
typename
GridwiseGemm
::
DsGridPointer
>
ds_ptr_
;
std
::
vector
<
const
EDataType
*>
e_ptr_
;
std
::
vector
<
GemmBiasTransKernelArg
>
gemm_desc_kernel_arg_
;
std
::
vector
<
GemmBiasTransKernelArg
>
gemm_desc_kernel_arg_
;
index_t
grid_size_
;
index_t
grid_size_
;
...
@@ -542,16 +571,70 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
...
@@ -542,16 +571,70 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
}
}
}
}
int
wg_off
=
0
;
const
int
align
=
4
;
void
*
a_ptr_dev
=
static_cast
<
char
*>
(
arg
.
p_workspace_
)
+
wg_off
;
hipGetErrorString
(
hipMemcpy
(
a_ptr_dev
,
arg
.
a_ptr_
.
data
(),
arg
.
a_ptr_
.
size
()
*
sizeof
(
ADataType
*
),
hipMemcpyHostToDevice
));
auto
a_ptr
=
static_cast
<
const
ADataType
*
const
*>
(
arg
.
a_ptr_
.
data
());
wg_off
+=
arg
.
a_ptr_
.
size
()
*
sizeof
(
ADataType
*
);
wg_off
=
math
::
integer_least_multiple
(
wg_off
,
align
);
void
*
b_ptr_dev
=
static_cast
<
char
*>
(
arg
.
p_workspace_
)
+
wg_off
;
hipGetErrorString
(
hipMemcpy
(
b_ptr_dev
,
arg
.
b_ptr_
.
data
(),
arg
.
b_ptr_
.
size
()
*
sizeof
(
BDataType
*
),
hipMemcpyHostToDevice
));
wg_off
+=
arg
.
b_ptr_
.
size
()
*
sizeof
(
BDataType
*
);
wg_off
=
math
::
integer_least_multiple
(
wg_off
,
align
);
void
*
ds_ptr_dev
=
static_cast
<
char
*>
(
arg
.
p_workspace_
)
+
wg_off
;
hipGetErrorString
(
hipGetErrorString
(
hipMemcpy
(
arg
.
p_workspace_
,
hipMemcpy
(
ds_ptr_dev
,
arg
.
ds_ptr_
.
data
(),
arg
.
ds_ptr_
.
size
()
*
sizeof
(
typename
GridwiseGemm
::
DsGridPointer
),
hipMemcpyHostToDevice
));
wg_off
+=
arg
.
ds_ptr_
.
size
()
*
sizeof
(
typename
GridwiseGemm
::
DsGridPointer
);
wg_off
=
math
::
integer_least_multiple
(
wg_off
,
align
);
void
*
e_ptr_dev
=
static_cast
<
char
*>
(
arg
.
p_workspace_
)
+
wg_off
;
hipGetErrorString
(
hipMemcpy
(
e_ptr_dev
,
arg
.
e_ptr_
.
data
(),
arg
.
e_ptr_
.
size
()
*
sizeof
(
EDataType
*
),
hipMemcpyHostToDevice
));
wg_off
+=
arg
.
e_ptr_
.
size
()
*
sizeof
(
EDataType
*
);
wg_off
=
math
::
integer_least_multiple
(
wg_off
,
align
);
void
*
gemm_desc_dev
=
static_cast
<
char
*>
(
arg
.
p_workspace_
)
+
wg_off
;
hipGetErrorString
(
hipMemcpy
(
gemm_desc_dev
,
arg
.
gemm_desc_kernel_arg_
.
data
(),
arg
.
gemm_desc_kernel_arg_
.
data
(),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmBiasTransKernelArg
),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmBiasTransKernelArg
),
hipMemcpyHostToDevice
));
hipMemcpyHostToDevice
));
wg_off
+=
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmBiasTransKernelArg
);
wg_off
=
math
::
integer_least_multiple
(
wg_off
,
align
);
float
ave_time
=
0
;
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop_
)
{
const
auto
kernel
=
kernel_grouped_gemm_xdl
<
GridwiseGemm
,
const
auto
kernel
=
kernel_grouped_gemm_xdl
<
GridwiseGemm
,
ADataType
,
BDataType
,
EDataType
,
GemmBiasTransKernelArg
,
GemmBiasTransKernelArg
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
...
@@ -564,7 +647,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
...
@@ -564,7 +647,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
dim3
(
arg
.
grid_size_
),
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
cast_pointer_to_constant_address_space
(
arg
.
p_workspace_
),
#if 0
cast_pointer_to_constant_address_space(a_ptr_dev),
cast_pointer_to_constant_address_space(b_ptr_dev),
cast_pointer_to_constant_address_space(ds_ptr_dev),
cast_pointer_to_constant_address_space(e_ptr_dev),
#endif
static_cast
<
const
ADataType
**>
(
a_ptr_dev
),
static_cast
<
const
BDataType
**>
(
b_ptr_dev
),
static_cast
<
const
typename
GridwiseGemm
::
DsGridPointer
*>
(
ds_ptr_dev
),
static_cast
<
EDataType
**>
(
e_ptr_dev
),
cast_pointer_to_constant_address_space
(
gemm_desc_dev
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
...
@@ -670,7 +763,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
...
@@ -670,7 +763,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
{
{
return
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
sizeof
(
GemmBiasTransKernelArg
);
size_t
wg_size
=
dynamic_cast
<
const
Argument
*>
(
p_arg
)
->
group_count_
*
(
sizeof
(
GemmBiasTransKernelArg
)
+
sizeof
(
ADataType
*
)
+
sizeof
(
BDataType
*
)
+
sizeof
(
typename
GridwiseGemm
::
DsGridPointer
)
+
sizeof
(
EDataType
*
));
const
int
align
=
4
;
wg_size
=
math
::
integer_least_multiple
(
wg_size
,
align
);
return
wg_size
;
}
}
};
};
...
...
script/cmake-ck-dev.sh
View file @
6f774178
...
@@ -10,8 +10,8 @@ cmake
...
@@ -10,8 +10,8 @@ cmake
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=
$PWD
"
\
-D
CMAKE_CXX_FLAGS
=
"-O3 -ftemplate-backtrace-limit=0 -gline-tables-only -save-temps=
$PWD
"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
O
N
\
-D
BUILD_DEV
=
O
FF
\
-D
GPU_TARGETS
=
"
gfx908;
gfx90a"
\
-D
GPU_TARGETS
=
"gfx90a"
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
${
MY_PROJECT_SOURCE
}
${
MY_PROJECT_SOURCE
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment