Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
5af78ac2
Commit
5af78ac2
authored
Jul 26, 2023
by
ltqin
Browse files
fix triagle name
parent
4a653a5d
Changes
42
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
65 additions
and
65 deletions
+65
-65
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v2.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
...ice/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
+4
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
...pu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
+4
-4
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
...ck/tensor_operation/gpu/device/masking_specialization.hpp
+9
-9
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
...n_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
+4
-4
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
...nsor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
+4
-4
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
...ration_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
+4
-4
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
...cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
+2
-2
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
...xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+2
-2
profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
...r/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
+1
-1
profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
...clude/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
+3
-3
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
...ofiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
+1
-1
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
...mute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
+3
-3
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
...mute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
+3
-3
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp
...mute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp
+2
-2
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
...m_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
+3
-3
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_kloop_v2.hpp
View file @
5af78ac2
...
...
@@ -500,13 +500,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Kloop_Xdl_CShuffle_V2
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v1.hpp
View file @
5af78ac2
...
...
@@ -505,13 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V1
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_bwd_xdl_cshuffle_qloop_v2.hpp
View file @
5af78ac2
...
...
@@ -505,13 +505,13 @@ struct DeviceGroupedMultiheadAttentionBackward_Qloop_Xdl_CShuffle_V2
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v1.hpp
View file @
5af78ac2
...
...
@@ -401,13 +401,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V1
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_mha_fwd_xdl_cshuffle_v2.hpp
View file @
5af78ac2
...
...
@@ -407,13 +407,13 @@ struct DeviceGroupedMultiheadAttentionForward_Xdl_CShuffle_V2
{
return
MaskDisabledPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
return
MaskUpperTringleFromTopLeftPredicate
{};
return
MaskUpperTri
a
ngleFromTopLeftPredicate
{};
}
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
)
else
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
)
{
return
MaskUpperTringleFromBottomRightPredicate
{};
return
MaskUpperTri
a
ngleFromBottomRightPredicate
{};
}
}
using
C0MatrixMask
=
C0MatrixMask_impl
<
decltype
(
make_MaskOutPredicate
())
>
;
...
...
include/ck/tensor_operation/gpu/device/masking_specialization.hpp
View file @
5af78ac2
...
...
@@ -10,8 +10,8 @@ namespace device {
enum
struct
MaskingSpecialization
{
MaskDisabled
,
MaskUpperTringleFromTopLeft
,
MaskUpperTringleFromBottomRight
MaskUpperTri
a
ngleFromTopLeft
,
MaskUpperTri
a
ngleFromBottomRight
};
inline
std
::
string
getMaskingSpecializationString
(
const
MaskingSpecialization
&
s
)
...
...
@@ -19,9 +19,9 @@ inline std::string getMaskingSpecializationString(const MaskingSpecialization& s
switch
(
s
)
{
case
MaskingSpecialization
::
MaskDisabled
:
return
"MaskDisabled"
;
case
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
:
return
"MaskUpperTringleFromTopLeft"
;
case
MaskingSpecialization
::
MaskUpperTringleFromBottomRight
:
return
"MaskUpperTringleFromBottomRight"
;
case
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
:
return
"MaskUpperTri
a
ngleFromTopLeft"
;
case
MaskingSpecialization
::
MaskUpperTri
a
ngleFromBottomRight
:
return
"MaskUpperTri
a
ngleFromBottomRight"
;
default:
return
"Unrecognized specialization!"
;
}
}
...
...
@@ -40,7 +40,7 @@ struct MaskDisabledPredicate
}
};
struct
MaskUpperTringleFromTopLeftPredicate
struct
MaskUpperTri
a
ngleFromTopLeftPredicate
{
__host__
__device__
constexpr
bool
operator
()(
index_t
m
,
index_t
n
)
const
{
return
n
>
m
;
}
...
...
@@ -50,9 +50,9 @@ struct MaskUpperTringleFromTopLeftPredicate
return
operator
()(
m
+
m_tile
-
1
,
n
);
}
};
struct
MaskUpperTringleFromBottomRightPredicate
struct
MaskUpperTri
a
ngleFromBottomRightPredicate
{
MaskUpperTringleFromBottomRightPredicate
()
:
offset_
(
0
)
{}
MaskUpperTri
a
ngleFromBottomRightPredicate
()
:
offset_
(
0
)
{}
__host__
__device__
void
SetOffset
(
const
index_t
offset
)
{
offset_
=
offset
;
}
__host__
__device__
constexpr
bool
operator
()(
index_t
m
,
index_t
n
)
const
{
...
...
@@ -77,7 +77,7 @@ struct C0MatrixMask_impl
C0MatrixMask_impl
(
index_t
MRaw
,
index_t
NRaw
)
:
NRaw_
(
NRaw
),
predicate_
(
MaskOutPredicate
{})
{
if
constexpr
(
std
::
is_same
<
MaskOutPredicate
,
MaskUpperTringleFromBottomRightPredicate
>::
value
)
MaskUpperTri
a
ngleFromBottomRightPredicate
>::
value
)
{
if
(
NRaw
>
MRaw
)
predicate_
.
SetOffset
(
NRaw
-
MRaw
);
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
View file @
5af78ac2
...
...
@@ -35,7 +35,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>>>&
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>>>&
instances
);
void
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
...
...
@@ -77,7 +77,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>>>&
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>>>&
instances
);
void
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
...
...
@@ -153,7 +153,7 @@ struct DeviceOperationInstanceFactory<
Acc0BiasDataType
::
Size
()
==
1
&&
is_same_v
<
tuple_element_t
<
0
,
Acc0BiasDataType
>
,
half_t
>
)
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
...
...
@@ -169,7 +169,7 @@ struct DeviceOperationInstanceFactory<
Acc0BiasDataType
::
Size
()
==
1
&&
is_same_v
<
tuple_element_t
<
0
,
Acc0BiasDataType
>
,
BF16
>
)
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm.hpp
View file @
5af78ac2
...
...
@@ -57,7 +57,7 @@ template <typename ALayout,
typename
B0DataType
,
typename
B1DataType
,
typename
CDataType
,
bool
MaskUpperTringleFromTopLeft
>
bool
MaskUpperTri
a
ngleFromTopLeft
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceBatchedGemmSoftmaxGemm
<
ALayout
,
B0Layout
,
...
...
@@ -72,7 +72,7 @@ struct DeviceOperationInstanceFactory<
Scale
,
PassThrough
,
PassThrough
,
MaskUpperTringleFromTopLeft
>>
MaskUpperTri
a
ngleFromTopLeft
>>
{
using
DeviceOp
=
DeviceBatchedGemmSoftmaxGemm
<
ALayout
,
B0Layout
,
...
...
@@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory<
Scale
,
PassThrough
,
PassThrough
,
MaskUpperTringleFromTopLeft
>
;
MaskUpperTri
a
ngleFromTopLeft
>
;
static
auto
GetInstances
()
{
...
...
@@ -99,7 +99,7 @@ struct DeviceOperationInstanceFactory<
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
B0Layout
,
Col
>
&&
is_same_v
<
B1Layout
,
Row
>
&&
is_same_v
<
CLayout
,
Row
>
)
{
if
constexpr
(
MaskUpperTringleFromTopLeft
)
if
constexpr
(
MaskUpperTri
a
ngleFromTopLeft
)
{
add_device_batched_gemm_masking_softmax_gemm_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance
(
op_ptrs
);
...
...
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp
View file @
5af78ac2
...
...
@@ -35,7 +35,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>>>&
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>>>&
instances
);
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
...
...
@@ -77,7 +77,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>>>&
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>>>&
instances
);
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
...
...
@@ -150,7 +150,7 @@ struct DeviceOperationInstanceFactory<
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
B0DataType
,
half_t
>
&&
is_same_v
<
B1DataType
,
half_t
>
&&
is_same_v
<
CDataType
,
half_t
>
)
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
...
...
@@ -164,7 +164,7 @@ struct DeviceOperationInstanceFactory<
else
if
constexpr
(
is_same_v
<
ADataType
,
BF16
>
&&
is_same_v
<
B0DataType
,
BF16
>
&&
is_same_v
<
B1DataType
,
BF16
>
&&
is_same_v
<
CDataType
,
BF16
>
)
{
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
)
if
constexpr
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
)
{
add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
op_ptrs
);
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
5af78ac2
...
...
@@ -83,7 +83,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>>>&
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>>>&
instances
)
{
add_device_operation_instances
(
...
...
@@ -94,7 +94,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16
1
,
1
,
1
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
{});
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>
{});
}
void
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
5af78ac2
...
...
@@ -85,7 +85,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
ScaleAdd
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>>>&
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>>>&
instances
)
{
add_device_operation_instances
(
...
...
@@ -96,7 +96,7 @@ void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_
1
,
1
,
1
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
{});
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>
{});
}
void
add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
View file @
5af78ac2
...
...
@@ -81,7 +81,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>>>&
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>>>&
instances
)
{
add_device_operation_instances
(
...
...
@@ -92,7 +92,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
1
,
1
,
1
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
{});
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>
{});
}
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances
(
...
...
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
View file @
5af78ac2
...
...
@@ -83,7 +83,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
Scale
,
PassThrough
,
PassThrough
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>>>&
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>>>&
instances
)
{
add_device_operation_instances
(
...
...
@@ -94,7 +94,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
1
,
1
,
1
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
{});
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>
{});
}
void
add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances
(
...
...
profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp
View file @
5af78ac2
...
...
@@ -241,7 +241,7 @@ bool profile_batched_gemm_bias_softmax_gemm_permute_impl(bool do_verification,
});
// mask out upper triangle
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
&&
idx
[
1
]
<
idx
[
2
])
if
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
&&
idx
[
1
]
<
idx
[
2
])
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
...
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_impl.hpp
View file @
5af78ac2
...
...
@@ -31,7 +31,7 @@ template <typename ADataType,
typename
B0Layout
,
typename
B1Layout
,
typename
CLayout
,
bool
MaskUpperTringleFromTopLeft
>
bool
MaskUpperTri
a
ngleFromTopLeft
>
bool
profile_batched_gemm_softmax_gemm_impl
(
bool
do_verification
,
int
init_method
,
bool
do_log
,
...
...
@@ -211,7 +211,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
Acc0ElementOp
,
B1ElementOp
,
CElementOp
,
MaskUpperTringleFromTopLeft
>
;
MaskUpperTri
a
ngleFromTopLeft
>
;
// get device op instances
const
auto
op_ptrs
=
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
...
...
@@ -230,7 +230,7 @@ bool profile_batched_gemm_softmax_gemm_impl(bool do_verification,
// mask out upper triangle
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
MaskUpperTringleFromTopLeft
&&
idx
[
1
]
<
idx
[
2
])
if
(
MaskUpperTri
a
ngleFromTopLeft
&&
idx
[
1
]
<
idx
[
2
])
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
...
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp
View file @
5af78ac2
...
...
@@ -219,7 +219,7 @@ bool profile_batched_gemm_softmax_gemm_permute_impl(bool do_verification,
// mask out upper triangle
acc0_g_m_n
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
if
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
&&
idx
[
1
]
<
idx
[
2
])
if
(
MaskingSpec
==
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
&&
idx
[
1
]
<
idx
[
2
])
self
(
idx
)
=
-
ck
::
NumericLimits
<
float
>::
Infinity
();
});
...
...
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_bf16.cpp
View file @
5af78ac2
...
...
@@ -15,14 +15,14 @@ using I2_t = ck::Number<2>;
using
MaskDisabled_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskDisabled
>
;
using
MaskUpperTringleFromTopLeft_t
=
using
MaskUpperTri
a
ngleFromTopLeft_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
;
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<
BF16
>
,
ck
::
Tuple
<>
,
MaskDisabled_t
>
,
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<
BF16
>
,
ck
::
Tuple
<>
,
MaskUpperTringleFromTopLeft_t
>
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<
BF16
>
,
ck
::
Tuple
<>
,
MaskUpperTri
a
ngleFromTopLeft_t
>
>
;
// clang-format on
...
...
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_fp16.cpp
View file @
5af78ac2
...
...
@@ -15,14 +15,14 @@ using I2_t = ck::Number<2>;
using
MaskDisabled_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskDisabled
>
;
using
MaskUpperTringleFromTopLeft_t
=
using
MaskUpperTri
a
ngleFromTopLeft_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
;
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<>
,
MaskDisabled_t
>
,
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<>
,
MaskUpperTringleFromTopLeft_t
>
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
F16
,
F16
,
F16
,
F16
,
ck
::
Tuple
<
F16
>
,
ck
::
Tuple
<>
,
MaskUpperTri
a
ngleFromTopLeft_t
>
>
;
// clang-format on
...
...
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_bias_softmax_gemm_permute_util.hpp
View file @
5af78ac2
...
...
@@ -174,7 +174,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_FP16_M128_N128_K32_O128
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
;
// MaskUpperTringleFromTopLeft
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>
;
// MaskUpperTri
a
ngleFromTopLeft
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
{
...
...
@@ -321,7 +321,7 @@ struct DeviceInstanceWrapper_G2M1N1K1O1_TNTT_BF16_M128_N128_K32_O128
2
,
// CShuffleNXdlPerWavePerShuffle
S
<
1
,
32
,
1
,
8
>
,
// CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
8
,
// CShuffleBlockTransferScalarPerVector_NPerBlock
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
;
// MaskUpperTringleFromTopLeft
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>
;
// MaskUpperTri
a
ngleFromTopLeft
bool
IsSupported
(
int
M
,
int
N
,
int
K
,
int
O
)
{
...
...
test/batched_gemm_softmax_gemm_permute/test_batched_gemm_softmax_gemm_permute_bf16.cpp
View file @
5af78ac2
...
...
@@ -15,14 +15,14 @@ using I2_t = ck::Number<2>;
using
MaskDisabled_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskDisabled
>
;
using
MaskUpperTringleFromTopLeft_t
=
using
MaskUpperTri
a
ngleFromTopLeft_t
=
ck
::
integral_constant
<
MaskingSpecialization
,
MaskingSpecialization
::
MaskUpperTringleFromTopLeft
>
;
MaskingSpecialization
::
MaskUpperTri
a
ngleFromTopLeft
>
;
// clang-format off
using
KernelTypes
=
::
testing
::
Types
<
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
MaskDisabled_t
>
,
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
MaskUpperTringleFromTopLeft_t
>
std
::
tuple
<
I2_t
,
I1_t
,
I1_t
,
I1_t
,
I1_t
,
BF16
,
BF16
,
BF16
,
BF16
,
ck
::
Tuple
<>
,
ck
::
Tuple
<>
,
MaskUpperTri
a
ngleFromTopLeft_t
>
>
;
// clang-format on
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment