Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0ce9cacf
Commit
0ce9cacf
authored
Jul 25, 2023
by
Adam Osewski
Browse files
Get back to use constant memory for gemm descriptors.
parent
3644f0ec
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
7 deletions
+8
-7
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp
...mpl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp
+8
-7
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle_tile_loop.hpp
View file @
0ce9cacf
...
@@ -53,7 +53,7 @@ __global__ void
...
@@ -53,7 +53,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_grouped_gemm_xdl_splitk
(
const
void
*
gemm_desc
,
kernel_grouped_gemm_xdl_splitk
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_desc
s_const
,
const
index_t
tile_count
,
const
index_t
tile_count
,
const
index_t
k_batch
)
const
index_t
k_batch
)
{
{
...
@@ -63,9 +63,10 @@ __global__ void
...
@@ -63,9 +63,10 @@ __global__ void
constexpr
index_t
shared_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
();
constexpr
index_t
shared_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
();
__shared__
uint8_t
p_shared
[
shared_size
];
__shared__
uint8_t
p_shared
[
shared_size
];
index_t
tile_id
=
get_block_1d_id
();
index_t
tile_id
=
get_block_1d_id
();
const
index_t
grid_size
=
get_grid_size
();
const
index_t
grid_size
=
get_grid_size
();
const
auto
gemm_desc_ptr
=
reinterpret_cast
<
const
GemmDesc
*>
(
gemm_desc
);
const
auto
gemm_desc_ptr
=
reinterpret_cast
<
const
GemmDesc
*>
(
cast_pointer_to_generic_address_space
(
gemm_descs_const
));
static
constexpr
index_t
MPerBlock
=
GridwiseGemm
::
GetMPerBlock
();
static
constexpr
index_t
MPerBlock
=
GridwiseGemm
::
GetMPerBlock
();
static
constexpr
index_t
NPerBlock
=
GridwiseGemm
::
GetNPerBlock
();
static
constexpr
index_t
NPerBlock
=
GridwiseGemm
::
GetNPerBlock
();
...
@@ -144,7 +145,7 @@ __global__ void
...
@@ -144,7 +145,7 @@ __global__ void
}
}
#else
#else
ignore
=
gemm_desc
;
ignore
=
gemm_desc
s_const
;
ignore
=
tile_count
;
ignore
=
tile_count
;
ignore
=
k_batch
;
ignore
=
k_batch
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
...
@@ -502,7 +503,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -502,7 +503,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
const
auto
&
gemm_arg
=
arg
.
gemm_kernel_args_
[
i
];
const
auto
&
gemm_arg
=
arg
.
gemm_kernel_args_
[
i
];
if
(
stream_config
.
log_level_
>
0
)
if
(
stream_config
.
log_level_
>
0
)
{
{
gemm_arg
.
Print
();
//
gemm_arg.Print();
}
}
// Currently all groups use same kbatch value.
// Currently all groups use same kbatch value.
...
@@ -613,7 +614,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
...
@@ -613,7 +614,7 @@ struct DeviceGroupedGemmXdlSplitKCShuffle : public DeviceGroupedGemmSplitK<ALayo
dim3
(
ck
::
math
::
min
(
arg
.
grid_size_
,
max_occupancy_grid_size
)),
dim3
(
ck
::
math
::
min
(
arg
.
grid_size_
,
max_occupancy_grid_size
)),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
0
,
0
,
dev_gemm_args
,
cast_pointer_to_constant_address_space
(
dev_gemm_args
)
,
arg
.
grid_size_
,
arg
.
grid_size_
,
arg
.
K_BATCH
);
arg
.
K_BATCH
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment