Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e71afee2
Commit
e71afee2
authored
Jul 24, 2022
by
Jing Zhang
Browse files
add multiD support into batched_gemm_c_permute
parent
85978e02
Changes
3
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
204 additions
and
83 deletions
+204
-83
example/24_batched_gemm_c_permute/batched_gemm_c_permute_xdl_fp16.cpp
...atched_gemm_c_permute/batched_gemm_c_permute_xdl_fp16.cpp
+3
-0
include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp
...or_operation/gpu/device/device_batched_gemm_c_permute.hpp
+7
-21
include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp
...peration/gpu/device/device_batched_gemm_c_permute_xdl.hpp
+194
-62
No files found.
example/24_batched_gemm_c_permute/batched_gemm_c_permute_xdl_fp16.cpp
View file @
e71afee2
...
@@ -178,14 +178,17 @@ int main(int argc, char* argv[])
...
@@ -178,14 +178,17 @@ int main(int argc, char* argv[])
// do GEM
// do GEM
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_device_buf
.
GetDeviceBuffer
()),
{},
static_cast
<
EDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
static_cast
<
EDataType
*>
(
c_device_buf
.
GetDeviceBuffer
()),
M
,
M
,
N
,
N
,
K
,
K
,
stride_A
,
stride_A
,
stride_B
,
stride_B
,
{},
batch_stride_A
,
batch_stride_A
,
batch_stride_B
,
batch_stride_B
,
{},
batched_gemm_c_permute_desc
,
batched_gemm_c_permute_desc
,
batch_count
,
batch_count
,
a_element_op
,
a_element_op
,
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute.hpp
View file @
e71afee2
...
@@ -16,26 +16,32 @@ struct BatchedGemmCPermuteDesc
...
@@ -16,26 +16,32 @@ struct BatchedGemmCPermuteDesc
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
D
E
Layout
,
typename
DLayout
,
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
typename
CDEElementwiseOperation
>
struct
DeviceBatchedGemmCPermute
:
public
BaseOperator
struct
DeviceBatchedGemmCPermute
:
public
BaseOperator
{
{
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
virtual
std
::
unique_ptr
<
BaseArgument
>
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
const
void
*
p_b
,
std
::
array
<
const
void
*
,
NumDTensor
>
p_ds
,
void
*
p_c
,
void
*
p_c
,
index_t
M
,
index_t
M
,
index_t
N
,
index_t
N
,
index_t
K
,
index_t
K
,
index_t
stride_A
,
index_t
stride_A
,
index_t
stride_B
,
index_t
stride_B
,
std
::
array
<
index_t
,
NumDTensor
>
stride_Ds
,
index_t
batch_stride_A
,
index_t
batch_stride_A
,
index_t
batch_stride_B
,
index_t
batch_stride_B
,
std
::
array
<
index_t
,
NumDTensor
>
batch_stride_Ds
,
BatchedGemmCPermuteDesc
batched_gemm_c_permute_desc
,
BatchedGemmCPermuteDesc
batched_gemm_c_permute_desc
,
index_t
BatchCount
,
index_t
BatchCount
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
...
@@ -45,26 +51,6 @@ struct DeviceBatchedGemmCPermute : public BaseOperator
...
@@ -45,26 +51,6 @@ struct DeviceBatchedGemmCPermute : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
template
<
typename
ALayout
,
typename
BLayout
,
typename
DELayout
,
typename
ADataType
,
typename
BDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
>
using
DeviceBatchedGemmCPermutePtr
=
std
::
unique_ptr
<
DeviceBatchedGemmCPermute
<
ALayout
,
BLayout
,
DELayout
,
ADataType
,
BDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
>>
;
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_batched_gemm_c_permute_xdl.hpp
View file @
e71afee2
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment