gaoqiong / composable_kernel · Commit 96b0f78c
"...composable_kernel-1.git" did not exist on "61faf02b52df0ed4e15d74b999fe230435d58717"
Commit 96b0f78c, authored Sep 16, 2022 by wangshaojie6

Commit message: clang-format

Parent: 97dcc7b2
Showing 7 changed files with 159 additions and 145 deletions (+159 −145).
- example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp (+3 −2)
- include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+45 −37)
- library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp (+38 −37)
- library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp (+13 −13)
- profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp (+38 −35)
- test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp (+3 −2)
- test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp (+19 −19)
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_xdl_fp16.cpp (+3 −2)

@@ -370,8 +370,9 @@ int main(int argc, char* argv[])

```cpp
        ref_gemm0_invoker.Run(ref_gemm0_argument);

        // mask out upper triangle
        acc0_g_m_n.ForEach([&](auto& self, auto idx) {
            if(idx[1] < idx[2])
                self(idx) = -ck::NumericLimits<float>::Infinity();
        });

        auto ref_softmax = ReferenceSoftmaxInstance{};
```
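For orientation, here is a minimal, self-contained sketch of what the `ForEach` lambda above computes, with a plain `std::vector` standing in for CK's `Tensor` class (the G/M/N index order `idx[0..2]` is taken from the hunk; everything else is a local stand-in, not CK API):

```cpp
#include <limits>
#include <vector>

// Mask out the strict upper triangle of each G x M x N slice: entries with
// m < n are set to -infinity, so the subsequent softmax gives them zero
// weight -- the causal ("decoder lower triangular") attention mask.
void mask_upper_triangle(std::vector<float>& acc, int G, int M, int N)
{
    const float neg_inf = -std::numeric_limits<float>::infinity();
    for(int g = 0; g < G; ++g)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                if(m < n) // same test as idx[1] < idx[2] in the lambda above
                    acc[(g * M + m) * N + n] = neg_inf;
}
```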
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp (+45 −37, mode changed 100755 → 100644)

@@ -755,13 +755,13 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

```cpp
            running_max_new = NumericLimits<FloatGemmAcc>::Lowest();

            // decoder lower triangular mask
            const auto thread_cluster_idx =
                threadid_to_m_n_thread_cluster_adaptor.CalculateBottomIndex(
                    make_multi_index(get_thread_local_1d_id()));
            const auto thread_m_cluster_id = thread_cluster_idx[I0];
            const auto thread_n_cluster_id = thread_cluster_idx[I1];
            const index_t MPerRepeat       = MPerBlock / MXdlPerWave;
            const index_t NPerRepeat       = NPerBlock / NXdlPerWave;
            const index_t mstart           = m_block_data_idx_on_grid + thread_m_cluster_id;

            // gemm1 K loop
            index_t gemm1_k_block_outer_index = 0;
```
@@ -769,8 +769,11 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

```cpp
            {
                if constexpr(MaskOutUpperTriangle)
                {
                    auto gemm0_n_block_idx =
                        __builtin_amdgcn_readfirstlane(gemm1_k_block_outer_index * NPerBlock);
                    if((m_block_data_idx_on_grid < gemm0_n_block_idx) &&
                       ((m_block_data_idx_on_grid + MPerBlock - 1) <
                        (gemm0_n_block_idx + NPerBlock - 1)))
                    {
                        continue;
                    }
```
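The `continue` above skips an entire gemm0 N-tile that the causal mask would set fully to −infinity. A standalone restatement of that predicate, with local stand-in names rather than CK's API:

```cpp
// Restatement of the tile-skip test in the hunk above. For tile starts
// aligned to block multiples with MPerBlock == NPerBlock, mstart < nstart
// already implies every (m, n) in the tile satisfies m < n, i.e. the whole
// tile lies strictly above the diagonal and would be masked to -inf, so
// the block can skip this outer-loop iteration entirely.
bool tile_fully_masked(int mstart, int MPerBlock, int nstart, int NPerBlock)
{
    return (mstart < nstart) &&
           ((mstart + MPerBlock - 1) < (nstart + NPerBlock - 1));
}
```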
@@ -807,40 +810,45 @@ struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

```cpp
                }
                else
                {
                    const index_t nstart = gemm1_k_block_outer_index * NPerBlock;
                    static_for<0, m0, 1>{}([&](auto m0_i) {
                        const index_t m_global   = mstart + m0_i * MPerRepeat;
                        const index_t acc_idx_m0 = m0_i * n0 * n2 * n4;
                        static_for<0, n0, 1>{}([&](auto n0_i) {
                            // constexpr auto nrepeat_i = n0_i * NPerRepeat;
                            // const index_t nstartxdl = nstart + nrepeat_i;
                            const index_t nstartxdl  = nstart + n0_i * NPerRepeat;
                            const index_t acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4;
                            static_for<0, n2, 1>{}([&](auto n2_i) {
                                const index_t nstartgroup =
                                    nstartxdl + thread_n_cluster_id * n4 + n2_i * AccN3 * n4;
                                const index_t acc_idx_n2 = acc_idx_n0 + n2_i * n4;
                                static_for<0, n4, 1>{}([&](auto n4_i) {
                                    const index_t n_global = nstartgroup + n4_i;
                                    const auto acc_offset  = Number<acc_idx_n2 + n4_i>{};
                                    if(n_global > m_global)
                                    {
                                        acc_thread_buf(acc_offset) =
                                            -ck::NumericLimits<float>::Infinity();
                                    }
                                    else
                                    {
                                        // Acc0 elementwise Op
#if CK_WORKAROUND_SWDEV_XXXXXX_ATTN_KERNEL_CLANG_CANNOT_SCAVENGE_REGISTER
                                        acc_element_op(acc_thread_buf(acc_offset),
                                                       acc_thread_buf[acc_offset]);
#else
                                        ElementOpPredicatedResetNaNToMinusInf<PadN>{}.Run(
                                            acc_thread_buf(acc_offset),
                                            acc_element_op,
                                            acc_thread_buf[acc_offset]);
#endif
                                    }
                                });
                            });
                        });
                    });
                }

                block_sync_lds(); // wait for lds read in gemm0 blockwise gemm
```
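The chain of `acc_idx_m0` / `acc_idx_n0` / `acc_idx_n2` values in the hunk above is a row-major flattening of the `(m0_i, n0_i, n2_i, n4_i)` loop coordinate into a linear register offset. A small sketch to check the arithmetic, assuming the loop extents `n0`, `n2`, `n4` are compile-time constants as in the kernel:

```cpp
// Row-major flattening matching the kernel's running sums:
//   acc_idx_m0 = m0_i * n0 * n2 * n4
//   acc_idx_n0 = acc_idx_m0 + n0_i * n2 * n4
//   acc_idx_n2 = acc_idx_n0 + n2_i * n4
//   acc_offset = acc_idx_n2 + n4_i
constexpr int flat_acc_offset(int m0_i, int n0_i, int n2_i, int n4_i,
                              int n0, int n2, int n4)
{
    return ((m0_i * n0 + n0_i) * n2 + n2_i) * n4 + n4_i;
}

// Spot-check with arbitrary extents and coordinates.
static_assert(flat_acc_offset(1, 2, 3, 4, 5, 6, 8) ==
                  1 * 5 * 6 * 8 + 2 * 6 * 8 + 3 * 8 + 4,
              "flattened offset must match the kernel's accumulated indices");
```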
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute.hpp (+38 −37)

@@ -25,18 +25,18 @@ using CPermuteNumDims_G_M_O =

```cpp
void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<Row,
                                                                    Col,
                                                                    Row,
                                                                    CPermuteNumDims_G_M_O,
                                                                    F16,
                                                                    F16,
                                                                    F16,
                                                                    F16,
                                                                    PassThrough,
                                                                    PassThrough,
                                                                    Scale,
                                                                    PassThrough,
                                                                    PassThrough>>>& instances);

template <typename ALayout,
          typename B0Layout,
```

@@ -48,32 +48,32 @@ template <typename ALayout,

```cpp
          typename CDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
                                                                      B0Layout,
                                                                      B1Layout,
                                                                      CPermuteNumDims_G_M_Gemm1N,
                                                                      ADataType,
                                                                      B0DataType,
                                                                      B1DataType,
                                                                      CDataType,
                                                                      PassThrough,
                                                                      PassThrough,
                                                                      Scale,
                                                                      PassThrough,
                                                                      PassThrough>>
{
    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
                                                         B0Layout,
                                                         B1Layout,
                                                         CPermuteNumDims_G_M_Gemm1N,
                                                         ADataType,
                                                         B0DataType,
                                                         B1DataType,
                                                         CDataType,
                                                         PassThrough,
                                                         PassThrough,
                                                         Scale,
                                                         PassThrough,
                                                         PassThrough>;

    static auto GetInstances()
    {
```

@@ -83,7 +83,8 @@ struct DeviceOperationInstanceFactory<

```cpp
           is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<B0Layout, Col> &&
                         is_same_v<B1Layout, Row> &&
                         is_same_v<CPermuteNumDims_G_M_Gemm1N, CPermuteNumDims_G_M_O>)
            {
                add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
                    op_ptrs);
```
library/src/tensor_operation_instance/gpu/batched_gemm_masking_scale_softmax_gemm_permute/device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp (+13 −13)

@@ -27,7 +27,7 @@ using CPermuteNumDims_G_M_O =

```cpp
    S<2, 1, 1>; // "using CLayout = Row" has been replaced by CPermuteNumDims_G_M_O

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale       = ck::tensor_operation::element_wise::Scale;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded  =
```

@@ -61,18 +61,18 @@ using device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f1

```cpp
void add_device_batched_gemm_masking_scale_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<Row,
                                                                    Col,
                                                                    Row,
                                                                    CPermuteNumDims_G_M_O,
                                                                    F16,
                                                                    F16,
                                                                    F16,
                                                                    F16,
                                                                    PassThrough,
                                                                    PassThrough,
                                                                    Scale,
                                                                    PassThrough,
                                                                    PassThrough>>>& instances)
{
    add_device_operation_instances(
        instances,
```
profiler/include/profile_batched_gemm_masking_scale_softmax_gemm_permute_impl.hpp (+38 −35)

@@ -31,22 +31,22 @@ template <typename ADataType,

```cpp
          typename B1Layout,
          typename CPermuteNumDims_G_M_O>
bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verification,
                                                                  int init_method,
                                                                  bool do_log,
                                                                  bool time_kernel,
                                                                  int M,
                                                                  int N,
                                                                  int K,
                                                                  int O,
                                                                  int G0,
                                                                  int G1,
                                                                  int StrideA       = -1,
                                                                  int StrideB0      = -1,
                                                                  int StrideB1      = -1,
                                                                  int BatchStrideA  = -1,
                                                                  int BatchStrideB0 = -1,
                                                                  int BatchStrideB1 = -1,
                                                                  float alpha       = 1.f)
{
```

@@ -196,19 +196,20 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi

```cpp
    auto b1_element_op = B1ElementOp{};
    auto c_element_op  = CElementOp{};

    using DeviceOp =
        tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<ALayout,
                                                                      B0Layout,
                                                                      B1Layout,
                                                                      CPermuteNumDims_G_M_O,
                                                                      ADataType,
                                                                      B0DataType,
                                                                      B1DataType,
                                                                      CDataType,
                                                                      AElementOp,
                                                                      B0ElementOp,
                                                                      Acc0ElementOp,
                                                                      B1ElementOp,
                                                                      CElementOp>;

    // get device op instances
    const auto op_ptrs = tensor_operation::device::instance::DeviceOperationInstanceFactory<
```

@@ -226,8 +227,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi

```cpp
        ref_gemm0_invoker.Run(ref_gemm0_argument);

        // mask out upper triangle
        acc0_g_m_n.ForEach([&](auto& self, auto idx) {
            if(idx[1] < idx[2])
                self(idx) = -ck::NumericLimits<float>::Infinity();
        });

        auto ref_softmax = ReferenceSoftmaxInstance{};
```

@@ -319,8 +321,8 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi

```cpp
    {
        c_gs_ms_os_device_buf.FromDevice(c_gs_ms_os_device_result.mData.data());

        pass = pass & ck::utils::check_err(c_gs_ms_os_device_result.mData,
                                           c_gs_ms_os_host_result.mData);

        if(do_log)
        {
```

@@ -333,8 +335,9 @@ bool profile_batched_gemm_masking_scale_softmax_gemm_permute_impl(bool do_verifi

```cpp
            LogRangeAsType<float>(
                std::cout << "c_gs_ms_os_host_result : ", c_gs_ms_os_host_result.mData, ",")
                << std::endl;
            LogRangeAsType<float>(
                std::cout << "c_gs_ms_os_device_result : ", c_gs_ms_os_device_result.mData, ",")
                << std::endl;
        }
    }
```
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_fp16.cpp (+3 −2)

@@ -5,7 +5,8 @@

```cpp
#include "test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp"

template <typename Tuple>
class TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16
    : public TestBatchedGemmMaskingScaleSoftmaxGemmPermute<Tuple>
{
};
```

@@ -158,7 +159,7 @@ TYPED_TEST(TestBatchedGemmMaskingScaleSoftmaxGemmPermuteFP16, AdhocTest)

```cpp
        {49, 49, 64, 64, 4, 6},
        {64, 49, 64, 64, 4, 6},
        {1020, 1020, 64, 128, 4, 6},
        {576, 576, 64, 64, 4, 6},
    };
    this->bench_ = true;
    this->Run();
```
test/batched_gemm_masking_scale_softmax_gemm_permute/test_batched_gemm_masking_scale_softmax_gemm_permute_util.hpp (+19 −19)

@@ -42,15 +42,15 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test

```cpp
    void RunSingle(int M, int N, int K, int O, int G0, int G1)
    {
        bool pass = ck::profiler::profile_batched_gemm_masking_scale_softmax_gemm_permute_impl<
            ADataType,
            B0DataType,
            B1DataType,
            CDataType,
            ALayout,
            B0Layout,
            B1Layout,
            CPermuteNumDims_G_M_O>(verify_, 1, false, bench_, M, N, K, O, G0, G1);
        EXPECT_TRUE(pass);
    }
```

@@ -59,12 +59,12 @@ struct TestBatchedGemmMaskingScaleSoftmaxGemmPermute : public ::testing::Test

```cpp
    {
        for(auto lengths : this->lengths_)
        {
            int M  = lengths[0];
            int N  = lengths[1];
            int K  = lengths[2];
            int O  = lengths[3];
            int G0 = lengths[4];
            int G1 = lengths[5];

            this->RunSingle(M, N, K, O, G0, G1);
        }
```

@@ -75,7 +75,7 @@ template <GemmSpecialization GemmSpec>

```cpp
struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Scale       = ck::tensor_operation::element_wise::Scale;

    using ALayout  = Row;
    using B0Layout = Col;
```

@@ -174,8 +174,8 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128

```cpp
        K,
        O,
        0,            // BatchCount
        {0, 0, M, O}, // gs ms ns lengths
        {0, O, 0, 1}, // gs ms ns strides
        0,            // StrideA
        0,            // StrideB0
        0,            // StrideB1
```

@@ -184,7 +184,7 @@ struct DeviceInstanceWrapper_TNTT_FP16_M128_N128_K32_O128

```cpp
        0,              // BatchStrideB1
        PassThrough{},  // a_element_op
        PassThrough{},  // b0_element_op
        Scale{1.f},     // acc0_element_op
        PassThrough{},  // b1_element_op
        PassThrough{}); // c_element_op
```