Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
4ad62d7f
Commit
4ad62d7f
authored
May 24, 2022
by
Jing Zhang
Browse files
Use a CK_CONSTANT_ADDRESS_SPACE pointer kernel argument instead of a global __constant__ buffer
parent
69add6ff
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
40 additions
and
38 deletions
+40
-38
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
+40
-38
No files found.
include/ck/tensor_operation/gpu/device/device_grouped_gemm_xdl.hpp
View file @
4ad62d7f
...
...
@@ -17,10 +17,6 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
#define CK_GEMM_DESCS_CONSTANT_BUFF_SIZE 1048576 // 1MB for 1000 gemm_descs
__constant__
static
char
gemm_descs_const_
[
CK_GEMM_DESCS_CONSTANT_BUFF_SIZE
];
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
...
...
@@ -33,7 +29,8 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_gemm_xdlops_v2r3
(
const
index_t
group_count
,
kernel_grouped_gemm_xdlops_v2r3
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_const
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
)
...
...
@@ -43,7 +40,8 @@ __global__ void
const
index_t
block_id
=
get_block_1d_id
();
const
auto
gemm_desc_ptr
=
reinterpret_cast
<
const
GemmDesc
*>
(
gemm_descs_const_
);
const
auto
gemm_desc_ptr
=
reinterpret_cast
<
const
GemmDesc
*>
(
cast_pointer_to_generic_address_space
(
gemm_descs_const
));
index_t
group_id
=
0
;
for
(
index_t
i
=
0
;
i
<
group_count
;
i
++
)
...
...
@@ -465,18 +463,22 @@ struct DeviceGroupedGemmXdl
}
}
if
(
sizeof
(
GemmDescKernelArg
)
*
arg
.
gemm_desc_kernel_arg_
.
size
()
>
CK_GEMM_DESCS_CONSTANT_BUFF_SIZE
)
{
throw
std
::
runtime_error
(
"wrong! too many gemms"
);
}
KernelTimer
timer
;
timer
.
Start
();
void
*
gemm_descs_const_
;
hipGetErrorString
(
hipMalloc
(
&
gemm_descs_const_
,
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmDescKernelArg
)));
hipGetErrorString
(
hipMemcpyToSymbol
(
HIP_SYMBOL
(
gemm_descs_const_
),
arg
.
gemm_desc_kernel_arg_
.
data
(),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmDescKernelArg
),
0
,
hipMemcpyHostToDevice
));
hipMemcpy
(
gemm_descs_const_
,
arg
.
gemm_desc_kernel_arg_
.
data
(),
arg
.
gemm_desc_kernel_arg_
.
size
()
*
sizeof
(
GemmDescKernelArg
),
hipMemcpyHostToDevice
));
timer
.
End
();
std
::
cout
<<
"HipMemCpy time: "
<<
timer
.
GetElapsedTime
()
<<
std
::
endl
;
float
ave_time
=
0
;
...
...
@@ -492,15 +494,17 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation
,
true
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
0
,
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
0
,
cast_pointer_to_constant_address_space
(
gemm_descs_const_
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
}
else
{
...
...
@@ -514,15 +518,17 @@ struct DeviceGroupedGemmXdl
CElementwiseOperation
,
false
>
;
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
0
,
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
ave_time
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
dim3
(
arg
.
grid_size_
),
dim3
(
BlockSize
),
0
,
cast_pointer_to_constant_address_space
(
gemm_descs_const_
),
arg
.
gemm_desc_kernel_arg_
.
size
(),
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
}
return
ave_time
;
...
...
@@ -546,10 +552,6 @@ struct DeviceGroupedGemmXdl
if
(
ck
::
type_convert
<
ck
::
index_t
>
(
arg
.
gemm_desc_kernel_arg_
.
size
())
!=
arg
.
group_count_
)
return
false
;
if
(
sizeof
(
GemmDescKernelArg
)
*
arg
.
gemm_desc_kernel_arg_
.
size
()
>
CK_GEMM_DESCS_CONSTANT_BUFF_SIZE
)
return
false
;
return
true
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment