Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d6ea89ec
Commit
d6ea89ec
authored
Oct 16, 2024
by
Mirza Halilcevic
Browse files
Add descriptor and RTC workarounds for batched_gemm_multiple_d_gemm_multiple_d.
parent
d20c20a6
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
377 additions
and
43 deletions
+377
-43
codegen/src/device_batched_gemm_multiple_d_gemm_multiple_d_operation_xdl_cshuffle.cpp
...emm_multiple_d_gemm_multiple_d_operation_xdl_cshuffle.cpp
+9
-9
include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp
...device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp
+4
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
..._batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
+359
-29
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
...tched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
+5
-5
No files found.
codegen/src/device_batched_gemm_multiple_d_gemm_multiple_d_operation_xdl_cshuffle.cpp
View file @
d6ea89ec
...
@@ -331,7 +331,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
...
@@ -331,7 +331,7 @@ std::vector<Operation_Xdl_CShuffle> Operation_Xdl_CShuffle::CreateOperations(
prob
.
N
,
prob
.
N
,
prob
.
K
,
prob
.
K
,
prob
.
O
,
prob
.
O
,
x
.
tile_desc
.
gemm0_m_per_block
,
x
.
tile_desc
.
gemm0
1
_m_per_block
,
x
.
tile_desc
.
gemm0_n_per_block
,
x
.
tile_desc
.
gemm0_n_per_block
,
x
.
tile_desc
.
gemm0_k_per_block
,
x
.
tile_desc
.
gemm0_k_per_block
,
x
.
tile_desc
.
gemm1_n_per_block
,
x
.
tile_desc
.
gemm1_n_per_block
,
...
@@ -404,13 +404,13 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
...
@@ -404,13 +404,13 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
std
::
unordered_map
<
std
::
string
,
std
::
string
>
values
=
{
std
::
unordered_map
<
std
::
string
,
std
::
string
>
values
=
{
{
"name"
,
{
"name"
,
std
::
to_string
(
this
->
tile_desc
.
block_size
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
block_size
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0
1
_m_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
a
0
k1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
b
0
k1
)
+
std
::
to_string
(
this
->
tile_desc
.
ak1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
bk1
)
+
"_"
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
b1k1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
b1k1
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_Xdl_per_wave
)
+
"_"
+
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_Xdl_per_wave
)
+
"_"
+
...
@@ -426,7 +426,7 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
...
@@ -426,7 +426,7 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
MakeTuple
(
Transform
(
this
->
D1s
,
[](
auto
tensor
)
{
return
ToString
(
tensor
.
layout
);
}))},
MakeTuple
(
Transform
(
this
->
D1s
,
[](
auto
tensor
)
{
return
ToString
(
tensor
.
layout
);
}))},
{
"E1Layout"
,
ToString
(
this
->
E1
.
layout
)},
{
"E1Layout"
,
ToString
(
this
->
E1
.
layout
)},
{
"ADataType"
,
ToString
(
this
->
A0
.
element
)},
{
"A
0
DataType"
,
ToString
(
this
->
A0
.
element
)},
{
"B0DataType"
,
ToString
(
this
->
B0
.
element
)},
{
"B0DataType"
,
ToString
(
this
->
B0
.
element
)},
{
"Acc0DataType"
,
ToString
(
this
->
acc_type
)},
{
"Acc0DataType"
,
ToString
(
this
->
acc_type
)},
{
"D0sDataType"
,
{
"D0sDataType"
,
...
@@ -450,15 +450,15 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
...
@@ -450,15 +450,15 @@ Solution Operation_Xdl_CShuffle::ToSolution() const
{
"PadGemm1N"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm1_n
)},
{
"PadGemm1N"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm1_n
)},
{
"PadGemm1K"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm1_k
)},
{
"PadGemm1K"
,
std
::
to_string
(
this
->
padding_desc
.
pad_gemm1_k
)},
{
"NumGemm0KPrefetchStage"
,
std
::
to_string
(
this
->
tile_desc
.
num_gemm
0
k_prefetch_stage
)},
{
"NumGemm0KPrefetchStage"
,
std
::
to_string
(
this
->
tile_desc
.
num_gemmk_prefetch_stage
)},
{
"BlockSize"
,
std
::
to_string
(
this
->
tile_desc
.
block_size
)},
{
"BlockSize"
,
std
::
to_string
(
this
->
tile_desc
.
block_size
)},
{
"Gemm0MPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_m_per_block
)},
{
"Gemm0MPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0
1
_m_per_block
)},
{
"Gemm0NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)},
{
"Gemm0NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_n_per_block
)},
{
"Gemm0KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)},
{
"Gemm0KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm0_k_per_block
)},
{
"Gemm1NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)},
{
"Gemm1NPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_n_per_block
)},
{
"Gemm1KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)},
{
"Gemm1KPerBlock"
,
std
::
to_string
(
this
->
tile_desc
.
gemm1_k_per_block
)},
{
"A0K1"
,
std
::
to_string
(
this
->
tile_desc
.
a
0
k1
)},
{
"A0K1"
,
std
::
to_string
(
this
->
tile_desc
.
ak1
)},
{
"B0K1"
,
std
::
to_string
(
this
->
tile_desc
.
b
0
k1
)},
{
"B0K1"
,
std
::
to_string
(
this
->
tile_desc
.
bk1
)},
{
"B1K1"
,
std
::
to_string
(
this
->
tile_desc
.
b1k1
)},
{
"B1K1"
,
std
::
to_string
(
this
->
tile_desc
.
b1k1
)},
{
"MPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)},
{
"MPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
m_per_XDL
)},
{
"NPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)},
{
"NPerXDL"
,
std
::
to_string
(
this
->
tile_desc
.
n_per_XDL
)},
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm_multiple_d_gemm_multiple_d.hpp
View file @
d6ea89ec
...
@@ -3,8 +3,10 @@
...
@@ -3,8 +3,10 @@
#pragma once
#pragma once
#ifndef __HIPCC_RTC__
#include <iostream>
#include <iostream>
#include <vector>
#include <vector>
#endif
#include "device_base.hpp"
#include "device_base.hpp"
...
@@ -31,6 +33,7 @@ template <typename A0Layout,
...
@@ -31,6 +33,7 @@ template <typename A0Layout,
typename
CDE1ElementwiseOperation
>
typename
CDE1ElementwiseOperation
>
struct
DeviceBatchedGemmMultipleDGemmMultipleD
:
public
BaseOperator
struct
DeviceBatchedGemmMultipleDGemmMultipleD
:
public
BaseOperator
{
{
#ifndef __HIPCC_RTC__
static
constexpr
index_t
NumD0Tensor
=
D0sDataType
::
Size
();
static
constexpr
index_t
NumD0Tensor
=
D0sDataType
::
Size
();
static
constexpr
index_t
NumD1Tensor
=
D1sDataType
::
Size
();
static
constexpr
index_t
NumD1Tensor
=
D1sDataType
::
Size
();
...
@@ -65,6 +68,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
...
@@ -65,6 +68,7 @@ struct DeviceBatchedGemmMultipleDGemmMultipleD : public BaseOperator
CDE1ElementwiseOperation
cde1_element_op
)
=
0
;
CDE1ElementwiseOperation
cde1_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
#endif
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
View file @
d6ea89ec
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle_v1.hpp
View file @
d6ea89ec
...
@@ -303,10 +303,10 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -303,10 +303,10 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
return
false
;
return
false
;
}
}
if
(
!
block_2_e1tile_map
.
CheckValidity
(
e1_grid_desc_m_n
))
//
if(!block_2_e1tile_map.CheckValidity(e1_grid_desc_m_n))
{
//
{
return
false
;
//
return false;
}
//
}
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
// TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
return
true
;
return
true
;
...
@@ -952,7 +952,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
...
@@ -952,7 +952,7 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
else
else
{
{
static_for
<
0
,
acc0_thread_buf
.
Size
(),
1
>
{}(
static_for
<
0
,
acc0_thread_buf
.
Size
(),
1
>
{}(
[
&
](
auto
i
)
{
cde0_element_op
(
acc_thread_buf
(
i
),
acc0_thread_buf
[
i
]);
});
[
&
](
auto
i
)
{
cde0_element_op
(
acc
0
_thread_buf
(
i
),
acc0_thread_buf
[
i
]);
});
}
}
// gemm1
// gemm1
{
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment