Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
321b6c8e
Commit
321b6c8e
authored
Jul 25, 2023
by
ltqin
Browse files
change enum to MaskUpperTringleFrom
parent
b4514459
Changes
41
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
96 additions
and
71 deletions
+96
-71
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+9
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+9
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
+9
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+9
-4
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
...ck/tensor_operation/gpu/device/masking_specialization.hpp
+3
-3
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
...n_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
+4
-4
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
...nsor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
+4
-4
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
...ration_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
+4
-4
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+2
-2
profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
...r/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
+1
-1
profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
...clude/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
+17
-16
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
...ofiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
+1
-1
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
...mute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
+4
-3
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
...mute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
+4
-3
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp
...mute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp
+2
-2
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
...m_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
+4
-3
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp
...m_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp
+4
-3
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
321b6c8e
...
...
@@ -505,9 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -628,7 +632,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
...
...
@@ -828,7 +832,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
// C0 mask
const
auto
c0_matrix_mask
=
C0MatrixMask
(
b_grid_desc_g_n_k
.
GetLength
(
I1
));
const
auto
c0_matrix_mask
=
C0MatrixMask
(
a_grid_desc_g_m_k
.
GetLength
(
I1
),
b_grid_desc_g_n_k
.
GetLength
(
I1
));
grid_size_
+=
grid_size_grp
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
321b6c8e
...
...
@@ -505,9 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -636,7 +640,7 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
...
...
@@ -836,7 +840,8 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
// C0 mask
const
auto
c0_matrix_mask
=
C0MatrixMask
(
b_grid_desc_g_n_k
.
GetLength
(
I1
));
const
auto
c0_matrix_mask
=
C0MatrixMask
(
a_grid_desc_g_m_k
.
GetLength
(
I1
),
b_grid_desc_g_n_k
.
GetLength
(
I1
));
grid_size_
+=
grid_size_grp
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
View file @
321b6c8e
...
...
@@ -401,9 +401,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -531,7 +535,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
...
...
@@ -697,7 +701,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
// C0 mask
const
auto
c0_matrix_mask
=
C0MatrixMask
(
b_grid_desc_g_n_k
.
GetLength
(
I1
));
const
auto
c0_matrix_mask
=
C0MatrixMask
(
a_grid_desc_g_m_k
.
GetLength
(
I1
),
b_grid_desc_g_n_k
.
GetLength
(
I1
));
grid_size_
+=
grid_size_grp
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
321b6c8e
...
...
@@ -407,9 +407,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
return
MaskOutUpperTrianglePredicate
{};
return
MaskUpperTringleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
@@ -537,7 +541,7 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
Transform
::
matrix_padder
.
PadN
,
MaskingSpec
=
=
MaskingSpecialization
::
Mask
OutUpperTriang
le
,
MaskingSpec
!
=
MaskingSpecialization
::
Mask
Disab
le
d
,
Deterministic
>
;
using
Block2CTileMap
=
OffsettedBlockToCTileMap
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
;
...
...
@@ -708,7 +712,8 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
type_convert
<
index_t
>
(
lse_grid_desc_m
.
GetElementSpaceSize
()));
// C0 mask
const
auto
c0_matrix_mask
=
C0MatrixMask
(
b_grid_desc_g_n_k
.
GetLength
(
I1
));
const
auto
c0_matrix_mask
=
C0MatrixMask
(
a_grid_desc_g_m_k
.
GetLength
(
I1
),
b_grid_desc_g_n_k
.
GetLength
(
I1
));
grid_size_
+=
grid_size_grp
;
...
...
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
View file @
321b6c8e
...
...
@@ -10,7 +10,7 @@ namespace device {
enum
struct
MaskingSpecialization
{
MaskDisabled
,
Mask
Out
UpperTri
a
ngle
,
MaskUpperTringle
FromTopLeft
,
MaskUpperTringleFromBottomRight
};
...
...
@@ -19,7 +19,7 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
switch
(
s
)
{
case
MaskingSpecialization
::
MaskDisabled
:
return
"MaskDisabled"
;
case
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
:
return
"Mask
Out
UpperTri
a
ngle"
;
case
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
:
return
"MaskUpperTringle
FromTopLeft
"
;
case
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
:
return
"MaskUpperTringleFromBottomRight"
;
default:
return
"Unrecognized specialization!"
;
...
...
@@ -40,7 +40,7 @@ struct MaskDisabledPredicate
}
};
struct
Mask
Out
UpperTri
a
nglePredicate
struct
MaskUpperTringle
FromTopLeft
Predicate
{
__host__
__device__
constexpr
bool
operator
()(
index_t
m
,
index_t
n
)
const
{
return
n
>
m
;
}
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
View file @
321b6c8e
...
...
@@ -35,7 +35,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>>>&
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>>>&
instances
);
void
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
...
...
@@ -77,7 +77,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>>>&
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>>>&
instances
);
void
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
...
...
@@ -153,7 +153,7 @@ struct DeviceOperationInstanceFactory<
Acc0BiasDataType
::
Size
()
==
1
&&
is_same_v
<
tuple_element_t
<
0
,
Acc0BiasDataType
>
,
half_t
>
)
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
...
...
@@ -169,7 +169,7 @@ struct DeviceOperationInstanceFactory<
Acc0BiasDataType
::
Size
()
==
1
&&
is_same_v
<
tuple_element_t
<
0
,
Acc0BiasDataType
>
,
BF16
>
)
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
View file @
321b6c8e
...
...
@@ -57,7 +57,7 @@ template <typename ALayout,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
,
bool
Mask
Out
UpperTri
a
ngle
>
bool
MaskUpperTringle
FromTopLeft
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemm
<
ALayout
,
B0Layout
,
...
...
@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
Scale
,
PassThrough
,
PassThrough
,
Mask
Out
UpperTri
a
ngle
>>
MaskUpperTringle
FromTopLeft
>>
{
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemm
<
ALayout
,
B0Layout
,
...
...
@@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory<
Scale
,
PassThrough
,
PassThrough
,
Mask
Out
UpperTri
a
ngle
>
;
MaskUpperTringle
FromTopLeft
>
;
static
auto
GetInstances
()
{
...
...
@@ -99,7 +99,7 @@ struct DeviceOperationInstanceFactory<
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
B0Layout
,
Col
>
&&
is_same_v
<
B1Layout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
if
constexpr
(
Mask
Out
UpperTri
a
ngle
)
if
constexpr
(
MaskUpperTringle
FromTopLeft
)
{
add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
op_ptrs
);
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
View file @
321b6c8e
...
...
@@ -35,7 +35,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>>>&
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>>>&
instances
);
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
...
...
@@ -77,7 +77,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>>>&
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>>>&
instances
);
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
...
...
@@ -150,7 +150,7 @@ struct DeviceOperationInstanceFactory<
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
B0DataType
,
half_t
>
&&
is_same_v
<
B1DataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
...
...
@@ -164,7 +164,7 @@ struct DeviceOperationInstanceFactory<
else
if
constexpr
(
is_same_v
<
ADataType
,
BF16
>
&&
is_same_v
<
B0DataType
,
BF16
>
&&
is_same_v
<
B1DataType
,
BF16
>
&&
is_same_v
<
CDataType
,
BF16
>
)
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
)
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
)
{
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
321b6c8e
...
...
@@ -83,7 +83,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>>>&
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>>>&
instances
)
{
add_device_operation_instances
(
...
...
@@ -94,7 +94,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
1
,
1
,
1
,
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>
{});
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>
{});
}
void
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
321b6c8e
...
...
@@ -85,7 +85,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>>>&
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>>>&
instances
)
{
add_device_operation_instances
(
...
...
@@ -96,7 +96,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
1
,
1
,
1
,
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>
{});
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>
{});
}
void
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
321b6c8e
...
...
@@ -81,7 +81,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>>>&
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>>>&
instances
)
{
add_device_operation_instances
(
...
...
@@ -92,7 +92,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
1
,
1
,
1
,
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>
{});
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>
{});
}
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
321b6c8e
...
...
@@ -83,7 +83,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>>>&
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>>>&
instances
)
{
add_device_operation_instances
(
...
...
@@ -94,7 +94,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
1
,
1
,
1
,
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>
{});
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>
{});
}
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
...
...
profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
View file @
321b6c8e
...
...
@@ -241,7 +241,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
});
// mask out upper triangle
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
&&
idx
[
1
]
<
idx
[
2
])
if
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
&&
idx
[
1
]
<
idx
[
2
])
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
...
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
View file @
321b6c8e
...
...
@@ -31,7 +31,7 @@ template <typename ADataType,
typename
B0Layout
,
typename
B1Layout
,
typename
CLayout
,
bool
Mask
Out
UpperTri
a
ngle
>
bool
MaskUpperTringle
FromTopLeft
>
bool
profile_batched_gemm_softmax_gemm_impl
(
bool
do_verification
,
int
init_method
,
bool
do_log
,
...
...
@@ -197,7 +197,8 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
auto
b1_element_op
=
B1ElementOp
{};
auto
c_element_op
=
CElementOp
{};
using
DeviceOp
=
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemm
<
ALayout
,
using
DeviceOp
=
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemm
<
ALayout
,
B0Layout
,
B1Layout
,
CLayout
,
...
...
@@ -210,7 +211,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
Mask
Out
UpperTri
a
ngle
>
;
MaskUpperTringle
FromTopLeft
>
;
// get device op instances
const
auto
op_ptrs
=
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
...
@@ -229,7 +230,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
// mask out upper triangle
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
Mask
Out
UpperTri
a
ngle
&&
idx
[
1
]
<
idx
[
2
])
if
(
MaskUpperTringle
FromTopLeft
&&
idx
[
1
]
<
idx
[
2
])
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
...
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
View file @
321b6c8e
...
...
@@ -219,7 +219,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
// mask out upper triangle
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
MaskingSpec
==
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
&&
idx
[
1
]
<
idx
[
2
])
if
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
&&
idx
[
1
]
<
idx
[
2
])
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
...
...
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
View file @
321b6c8e
...
...
@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>;
using
MaskDisabled_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskDisabled
>
;
using
MaskOutUpperTriangle_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
using
MaskUpperTringleFromTopLeft_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<
BF16
>
,
ck
::
Tuple
<>
,
MaskDisabled_t
>
,
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<
BF16
>
,
ck
::
Tuple
<>
,
Mask
Out
UpperTri
a
ngle_t
>
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<
BF16
>
,
ck
::
Tuple
<>
,
MaskUpperTringle
FromTopLeft
_t
>
>
;
// clang-format on
...
...
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
View file @
321b6c8e
...
...
@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>;
using
MaskDisabled_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskDisabled
>
;
using
MaskOutUpperTriangle_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
using
MaskUpperTringleFromTopLeft_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<>
,
MaskDisabled_t
>
,
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<>
,
Mask
Out
UpperTri
a
ngle_t
>
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<>
,
MaskUpperTringle
FromTopLeft
_t
>
>
;
// clang-format on
...
...
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp
View file @
321b6c8e
...
...
@@ -174,7 +174,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>
;
// Mask
Out
UpperTri
a
ngle
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>
;
// MaskUpperTringle
FromTopLeft
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
{
...
...
@@ -321,7 +321,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization
::
Mask
Out
UpperTri
a
ngle
>
;
// Mask
Out
UpperTri
a
ngle
MaskingSpecialization
::
MaskUpperTringle
FromTopLeft
>
;
// MaskUpperTringle
FromTopLeft
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
{
...
...
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
View file @
321b6c8e
...
...
@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>;
using
MaskDisabled_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskDisabled
>
;
using
MaskOutUpperTriangle_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
using
MaskUpperTringleFromTopLeft_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
MaskDisabled_t
>
,
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
Mask
Out
UpperTri
a
ngle_t
>
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
MaskUpperTringle
FromTopLeft
_t
>
>
;
// clang-format on
...
...
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_fp16.cpp
View file @
321b6c8e
...
...
@@ -15,13 +15,14 @@ using I2_t = ck::Number<2>;
using
MaskDisabled_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskDisabled
>
;
using
MaskOutUpperTriangle_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskOutUpperTriangle
>
;
using
MaskUpperTringleFromTopLeft_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
MaskDisabled_t
>
,
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
Mask
Out
UpperTri
a
ngle_t
>
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
MaskUpperTringle
FromTopLeft
_t
>
>
;
// clang-format on
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment