gaoqiong / composable_kernel_ROCM

Commit 77a04d6a, authored Apr 04, 2024 by Jing Zhang
fixed register loads
Parent: 7d700bc0
Showing 3 changed files with 49 additions and 63 deletions:

  include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp                              +34 -44
  include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp                            +2  -2
  include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp  +13 -17
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
@@ -70,6 +70,9 @@ struct BlockwiseGemmWMMA
     static constexpr index_t A_KRow = 2;
     static constexpr index_t B_KRow = 2;
+    static constexpr index_t A_KRow_ = AEnableLds ? 1 : 2;
+    static constexpr index_t B_KRow_ = BEnableLds ? 1 : 2;
     static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
     static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
@@ -192,8 +195,8 @@ struct BlockwiseGemmWMMA
                           NPerBlock % (NPerWMMA * NRepeat) == 0,
                       "wrong!");

-        static_assert(AEnableLds == true, "only support EnableLds");
-        static_assert(BEnableLds == true, "only support EnableLds");
+        // static_assert(AEnableLds == true, "only support EnableLds");
+        // static_assert(BEnableLds == true, "only support EnableLds");
     }

     // transposed WMMA output C' = B' * A'
@@ -316,7 +319,7 @@ struct BlockwiseGemmWMMA
                 // read A
                 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
-                                   make_tuple(Number<k * KPack / A_K1>{}, m0, I0, I0, I0, I0),
+                                   make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
                                    a_block_buf,
                                    a_thread_desc_,
                                    make_tuple(I0, m0, I0, I0, I0, I0),
@@ -326,7 +329,8 @@ struct BlockwiseGemmWMMA
                 // read B
                 b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
-                                   make_tuple(Number<k * KPack / B_K1>{}, n0, I0, I0, I0, I0),
+                                   make_tuple(Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
                                    b_block_buf,
                                    b_thread_desc_,
                                    make_tuple(I0, n0, I0, I0, I0, I0),
@@ -372,7 +376,7 @@ struct BlockwiseGemmWMMA
                 // read B
                 b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
-                                   make_tuple(Number<k * KPack / B_K1>{}, n0, I0, I0, I0, I0),
+                                   make_tuple(Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
                                    b_block_buf,
                                    b_thread_desc_,
                                    make_tuple(I0, n0, I0, I0, I0, I0),
@@ -380,7 +384,7 @@ struct BlockwiseGemmWMMA
                 // read A
                 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
-                                   make_tuple(Number<k * KPack / A_K1>{}, m0, I0, I0, I0, I0),
+                                   make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
                                    a_block_buf,
                                    a_thread_desc_,
                                    make_tuple(I0, m0, I0, I0, I0, I0),
@@ -442,13 +446,7 @@ struct BlockwiseGemmWMMA
     static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
         make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));

-    template <bool EnableLds>
-    struct AThreadCopySelector;
-
-    template <>
-    struct AThreadCopySelector<true>
-    {
-        using type =
+    using AThreadCopyType =
         ThreadwiseTensorSliceTransfer_v4<FloatA,
                                          FloatA,
                                          decltype(a_block_desc_k0_m0_m1_m2_k1),
@@ -458,15 +456,8 @@ struct BlockwiseGemmWMMA
                                          5,
                                          A_K1,
                                          A_K1>;
-    };
-
-    template <bool EnableLds>
-    struct BThreadCopySelector;
-
-    template <>
-    struct BThreadCopySelector<true>
-    {
-        using type =
+    using BThreadCopyType =
         ThreadwiseTensorSliceTransfer_v4<FloatB,
                                          FloatB,
                                          decltype(b_block_desc_k0_n0_n1_n2_k1),
@@ -476,10 +467,9 @@ struct BlockwiseGemmWMMA
                                          5,
                                          B_K1,
                                          B_K1>;
-    };
-
-    typename AThreadCopySelector<AEnableLds>::type a_thread_copy_;
-    typename BThreadCopySelector<BEnableLds>::type b_thread_copy_;
+    AThreadCopyType a_thread_copy_;
+    BThreadCopyType b_thread_copy_;
 };
 #else
 template <index_t BlockSize,
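Note on the hunks above: the diff introduces A_KRow_ / B_KRow_ (defined as AEnableLds ? 1 : 2 and BEnableLds ? 1 : 2) and adds an extra division by them to the source-K offsets passed to a_thread_copy_.Run / b_thread_copy_.Run, while the AThreadCopySelector / BThreadCopySelector wrappers appear to be collapsed into direct AThreadCopyType / BThreadCopyType aliases. Below is a minimal, standalone sketch of that offset arithmetic only; the values of KPack and A_K1 are invented for illustration and are not taken from the commit.

// Illustrative sketch, not part of composable_kernel: shows how the new A_KRow_
// factor rescales the per-iteration K offset used by the thread copy.
#include <cstdio>

int main()
{
    constexpr int KPack = 16; // hypothetical example value
    constexpr int A_K1  = 8;  // hypothetical example value

    for (bool AEnableLds : {true, false})
    {
        const int A_KRow_ = AEnableLds ? 1 : 2; // mirrors the new constant in BlockwiseGemmWMMA
        for (int k = 0; k < 2; ++k)
        {
            // old offset: k * KPack / A_K1; the diff divides once more by A_KRow_
            std::printf("AEnableLds=%d k=%d offset=%d\n", AEnableLds, k, k * KPack / A_K1 / A_KRow_);
        }
    }
    return 0;
}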
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
@@ -94,8 +94,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;

     // If true, LDS is used unconditionally
-    static constexpr auto AEnableLds_manu = true;
-    static constexpr auto BEnableLds_manu = true;
+    static constexpr auto AEnableLds_manu = false;
+    static constexpr auto BEnableLds_manu = false;

     static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
     static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
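The device-level change only touches the manual overrides; the effective flags are still the OR of the auto heuristic, the manual override, and the prefetch condition shown above. A minimal standalone sketch of how those pieces combine, using invented example values rather than the real template parameters of DeviceGemmWmma_CShuffle:

// Illustrative only: how the LDS-enable flags resolve once the manual override is false.
#include <cstdio>

int main()
{
    constexpr bool AEnableLds_auto = false; // pretend the layout/MWaves heuristic resolved to false
    constexpr bool AEnableLds_manu = false; // after this commit: no longer forces LDS unconditionally
    constexpr int  NumPrefetch     = 1;     // hypothetical example value

    constexpr bool AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    std::printf("AEnableLds = %d\n", AEnableLds); // 0 -> the non-LDS (register) path is selected
    return 0;
}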
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp
@@ -333,8 +333,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         // AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
         constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
         constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
-        // constexpr auto A_KRow = I1;
-        constexpr auto A_KRow = I2;
+        constexpr auto A_KRow = I1;

         return transform_tensor_descriptor(
             ABlockDesc_{},
             make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
@@ -374,8 +373,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         // BK0_L_BK1 -> BK0_LRepeat_Lwaves_LPerWmma_BK1
         constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
         constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
-        // constexpr auto B_KRow = I1;
-        constexpr auto B_KRow = I2;
+        constexpr auto B_KRow = I1;

         return transform_tensor_descriptor(
             B0BlockDesc_{},
             make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
@@ -412,8 +410,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
     {
         constexpr index_t A_L0 = A1BlockDesc_AL0_M_AL1{}.GetLength(I0);
         constexpr index_t A_L1 = A1BlockDesc_AL0_M_AL1{}.GetLength(I2);
-        // constexpr auto A_LRow = I1;
-        constexpr auto A_LRow = I2;
+        constexpr auto A_LRow = I1;

         return transform_tensor_descriptor(
             A1BlockDesc_AL0_M_AL1{},
             make_tuple(make_unmerge_transform(make_tuple(Number<A_L0>{}, A_LRow)),
@@ -433,8 +430,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
         // BL0_N_BL1 -> BL0_NRepeat_Nwaves_NPerWmma_BL1
         constexpr auto B_L0 = B1BlockDesc_{}.GetLength(I0);
         constexpr auto B_L1 = B1BlockDesc_{}.GetLength(I2);
-        // constexpr auto B_LRow = I1;
-        constexpr auto B_LRow = I2;
+        constexpr auto B_LRow = I1;

         return transform_tensor_descriptor(
             B1BlockDesc_{},
             make_tuple(make_unmerge_transform(make_tuple(Number<B_L0>{}, B_LRow)),
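The four hunks above replace the hard-coded A_KRow / B_KRow / A_LRow / B_LRow = I2 with I1 in the unmerge transforms that split the K0 (or L0) dimension into a (K0, KRow) pair. Conceptually, an unmerge over lengths (K0, KRow) decomposes a flat index as idx = k0 * KRow + krow, so KRow = 1 degenerates to the identity and the descriptor keeps its original K indexing. The snippet below is a conceptual sketch of that decomposition in plain C++, not the ck descriptor machinery itself.

// Conceptual sketch only: decompose a flat K index into (k0, krow) the way an
// unmerge over lengths (K0, KRow) does; KRow = 1 leaves the index unchanged.
#include <cstdio>
#include <utility>

std::pair<int, int> unmerge(int idx, int krow) { return {idx / krow, idx % krow}; }

int main()
{
    for (int idx = 0; idx < 4; ++idx)
    {
        auto [k0_new, r_new] = unmerge(idx, /*KRow=*/1); // value used after this commit
        auto [k0_old, r_old] = unmerge(idx, /*KRow=*/2); // previous hard-coded value
        std::printf("idx=%d  KRow=1 -> (%d,%d)   KRow=2 -> (%d,%d)\n",
                    idx, k0_new, r_new, k0_old, r_old);
    }
    return 0;
}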
@@ -1183,7 +1179,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
                                              MRepeat,
                                              NRepeat,
                                              KPack,
-                                             true,
+                                             false,
                                              B1EnableLds,
                                              true>{make_tuple(0, 0, 0, 0, 0, 0)};
@@ -1346,7 +1342,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
                     block_sync_lds();

-                    // blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
+                    blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);

                     block_sync_lds();
@@ -1369,7 +1365,7 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma
                     block_sync_lds();

-                    // blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
+                    blockwise_gemm1.Run(a1_thread_buf, b1_block_buf, acc1_thread_buf);
                 }
             }
             // end gemm1