Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
66f8e4bb
Commit
66f8e4bb
authored
Feb 14, 2024
by
Adam Osewski
Browse files
Support A/B/C elementwise ops.
parent
b9ab9f4b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
5 deletions
+18
-5
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
...u/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
+18
-5
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp
View file @
66f8e4bb
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -26,13 +26,20 @@ namespace device {
template
<
typename
GridwiseGemm
,
typename
GemmDesc
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
>
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
AElementwiseOperation
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
typename
BElementwiseOperation
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
typename
CElementwiseOperation
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_gemm_xdl_splitk
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_const
,
const
index_t
group_count
)
kernel_grouped_gemm_xdl_splitk
(
const
void
CK_CONSTANT_ADDRESS_SPACE
*
gemm_descs_const
,
const
index_t
group_count
,
const
AElementwiseOperation
a_element_op
=
AElementwiseOperation
{},
const
BElementwiseOperation
b_element_op
=
BElementwiseOperation
{},
const
CElementwiseOperation
c_element_op
=
CElementwiseOperation
{})
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
...
...
@@ -64,10 +71,16 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
>(
gemm_desc_ptr
[
group_id
].
karg_
,
static_cast
<
void
*>
(
p_shared
),
gemm_desc_ptr
[
group_id
].
block_2_ctile_map_
);
gemm_desc_ptr
[
group_id
].
block_2_ctile_map_
,
a_element_op
,
b_element_op
,
c_element_op
);
#else
ignore
=
gemm_descs_const
;
ignore
=
group_count
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment