Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
96b0f78c
Commit
96b0f78c
authored
Sep 16, 2022
by
wangshaojie6
Browse files
clang-format
parent
97dcc7b2
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
159 additions
and
145 deletions
+159
-145
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
...mm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
+3
-2
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
...id/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
+45
-37
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp
...e/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp
+38
-37
library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+13
-13
profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp
..._batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp
+38
-35
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp
..._batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp
+3
-2
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp
..._batched_gemm_masking_scale_softmax_gemm_permute_util.hpp
+19
-19
No files found.
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp
View file @
96b0f78c
...
...
@@ -371,7 +371,8 @@ int main(int argc, char* argv[])
// mask out upper triangle
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
idx
[
1
]
<
idx
[
2
])
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
if
(
idx
[
1
]
<
idx
[
2
])
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
auto
ref_softmax
=
ReferenceSoftmaxInstance
{};
...
...
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp
100755 → 100644
View file @
96b0f78c
...
...
@@ -755,8 +755,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
running_max_new
=
NumericLimits
<
FloatGemmAcc
>::
Lowest
();
// decoder lower triangular mask
const
auto
thread_cluster_idx
=
threadid_to_m_n_thread_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_cluster_idx
=
threadid_to_m_n_thread_cluster_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
get_thread_local_1d_id
()));
const
auto
thread_m_cluster_id
=
thread_cluster_idx
[
I0
];
const
auto
thread_n_cluster_id
=
thread_cluster_idx
[
I1
];
const
index_t
MPerRepeat
=
MPerBlock
/
MXdlPerWave
;
...
...
@@ -769,8 +769,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
if
constexpr
(
MaskOutUpperTriangle
)
{
auto
gemm0_n_block_idx
=
__builtin_amdgcn_readfirstlane
(
gemm1_k_block_outer_index
*
NPerBlock
);
if
((
m_block_data_idx_on_grid
<
gemm0_n_block_idx
)
&&
((
m_block_data_idx_on_grid
+
MPerBlock
-
1
)
<
(
gemm0_n_block_idx
+
NPerBlock
-
1
)))
auto
gemm0_n_block_idx
=
__builtin_amdgcn_readfirstlane
(
gemm1_k_block_outer_index
*
NPerBlock
);
if
((
m_block_data_idx_on_grid
<
gemm0_n_block_idx
)
&&
((
m_block_data_idx_on_grid
+
MPerBlock
-
1
)
<
(
gemm0_n_block_idx
+
NPerBlock
-
1
)))
{
continue
;
}
...
...
@@ -818,24 +821,29 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle
const
index_t
nstartxdl
=
nstart
+
n0_i
*
NPerRepeat
;
const
index_t
acc_idx_n0
=
acc_idx_m0
+
n0_i
*
n2
*
n4
;
static_for
<
0
,
n2
,
1
>
{}([
&
](
auto
n2_i
)
{
const
index_t
nstartgroup
=
nstartxdl
+
thread_n_cluster_id
*
n4
+
n2_i
*
AccN3
*
n4
;
const
index_t
nstartgroup
=
nstartxdl
+
thread_n_cluster_id
*
n4
+
n2_i
*
AccN3
*
n4
;
const
index_t
acc_idx_n2
=
acc_idx_n0
+
n2_i
*
n4
;
static_for
<
0
,
n4
,
1
>
{}([
&
](
auto
n4_i
)
{
const
index_t
n_global
=
nstartgroup
+
n4_i
;
const
auto
acc_offset
=
Number
<
acc_idx_n2
+
n4_i
>
{};
if
(
n_global
>
m_global
)
if
(
n_global
>
m_global
)
{
acc_thread_buf
(
acc_offset
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
acc_thread_buf
(
acc_offset
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
}
else
{
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
acc_element_op
(
acc_thread_buf
(
acc_offset
),
acc_thread_buf
[
acc_offset
]);
#else
// Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
acc_element_op
(
acc_thread_buf
(
acc_offset
),
acc_thread_buf
[
acc_offset
]);
#else
ElementOpPredicatedResetNaNToMinusInf
<
PadN
>
{}.
Run
(
acc_thread_buf
(
acc_offset
),
acc_element_op
,
acc_thread_buf
[
acc_offset
]);
#endif
acc_thread_buf
(
acc_offset
),
acc_element_op
,
acc_thread_buf
[
acc_offset
]);
#endif
}
});
});
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp
View file @
96b0f78c
...
...
@@ -83,7 +83,8 @@ struct DeviceOperationInstanceFactory<
is_same_v
<
B1DataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
B0Layout
,
Col
>
&&
is_same_v
<
B1Layout
,
Row
>
&&
is_same_v
<
CPermuteNumDims_G_M_Gemm1N
,
CPermuteNumDims_G_M_O
>
)
is_same_v
<
B1Layout
,
Row
>
&&
is_same_v
<
CPermuteNumDims_G_M_Gemm1N
,
CPermuteNumDims_G_M_O
>
)
{
add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
op_ptrs
);
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
96b0f78c
profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp
View file @
96b0f78c
...
...
@@ -196,7 +196,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
auto
b1_element_op
=
B1ElementOp
{};
auto
c_element_op
=
CElementOp
{};
using
DeviceOp
=
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute
<
ALayout
,
using
DeviceOp
=
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemmPermute
<
ALayout
,
B0Layout
,
B1Layout
,
CPermuteNumDims_G_M_O
,
...
...
@@ -227,7 +228,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
// mask out upper triangle
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
idx
[
1
]
<
idx
[
2
])
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
if
(
idx
[
1
]
<
idx
[
2
])
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
auto
ref_softmax
=
ReferenceSoftmaxInstance
{};
...
...
@@ -319,8 +321,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
{
c_gs_ms_os_device_buf
.
FromDevice
(
c_gs_ms_os_device_result
.
mData
.
data
());
pass
=
pass
&
ck
::
utils
::
check_err
(
c_gs_ms_os_device_result
.
mData
,
c_gs_ms_os_host_result
.
mData
);
pass
=
pass
&
ck
::
utils
::
check_err
(
c_gs_ms_os_device_result
.
mData
,
c_gs_ms_os_host_result
.
mData
);
if
(
do_log
)
{
...
...
@@ -333,8 +335,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_gs_ms_os_host_result : "
,
c_gs_ms_os_host_result
.
mData
,
","
)
<<
std
::
endl
;
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_gs_ms_os_device_result : "
,
c_gs_ms_os_device_result
.
mData
,
","
)
LogRangeAsType
<
float
>
(
std
::
cout
<<
"c_gs_ms_os_device_result : "
,
c_gs_ms_os_device_result
.
mData
,
","
)
<<
std
::
endl
;
}
}
...
...
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp
View file @
96b0f78c
...
...
@@ -5,7 +5,8 @@
#include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp"
template
<
typename
Tuple
>
class
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
:
public
TestBatchedGemmMaskingScaleSoftmaxGemmPermute
<
Tuple
>
class
TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
:
public
TestBatchedGemmMaskingScaleSoftmaxGemmPermute
<
Tuple
>
{
};
...
...
@@ -158,7 +159,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)
{
49
,
49
,
64
,
64
,
4
,
6
},
{
64
,
49
,
64
,
64
,
4
,
6
},
{
1020
,
1020
,
64
,
128
,
4
,
6
},
{
576
,
576
,
64
,
64
,
4
,
6
},
{
576
,
576
,
64
,
64
,
4
,
6
},
};
this
->
bench_
=
true
;
this
->
Run
();
...
...
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp
View file @
96b0f78c
...
...
@@ -42,15 +42,15 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test
void
RunSingle
(
int
M
,
int
N
,
int
K
,
int
O
,
int
G0
,
int
G1
)
{
bool
pass
=
ck
::
profiler
::
profile_batched_gemm_masking_scale_softmax_gemm_permute_impl
<
ADataType
,
bool
pass
=
ck
::
profiler
::
profile_batched_gemm_masking_scale_softmax_gemm_permute_impl
<
ADataType
,
B0DataType
,
B1DataType
,
CDataType
,
ALayout
,
B0Layout
,
B1Layout
,
CPermuteNumDims_G_M_O
>
(
verify_
,
1
,
false
,
bench_
,
M
,
N
,
K
,
O
,
G0
,
G1
);
CPermuteNumDims_G_M_O
>
(
verify_
,
1
,
false
,
bench_
,
M
,
N
,
K
,
O
,
G0
,
G1
);
EXPECT_TRUE
(
pass
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment