gaoqiong / composable_kernel · Commits · 8d4b916e
"docs/source/en/api/image_processor.md" did not exist on "500a3ff9ef53fafc52a01e94e1d88b1f7c502928"
Commit 8d4b916e authored Mar 31, 2023 by ltqin

    bias mask out of iner op

Parent: a483948c

Showing 11 changed files with 170 additions and 685 deletions (+170 −685)
client_example/08_fused_attention/CMakeLists.txt  +4 −4
client_example/08_fused_attention/fused_attention_bias_mask_no_lib.cpp  +34 −2
client_example/08_fused_attention/fused_attention_mask_no_lib.cpp  +1 −1
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp  +0 −31
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp  +0 −190
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp  +129 −61
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp  +0 −293
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt  +0 −1
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp  +0 −100
profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp  +1 −1
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp  +1 −1
client_example/08_fused_attention/CMakeLists.txt

@@ -4,8 +4,8 @@ target_link_libraries(client_fused_attention PRIVATE composable_kernel::device_o
 add_executable(client_fused_attention_bias fused_attention_bias.cpp)
 target_link_libraries(client_fused_attention_bias PRIVATE composable_kernel::device_operations)
-add_executable(client_fused_attention_bias_mask fused_attention_bias_mask.cpp)
-target_link_libraries(client_fused_attention_bias_mask PRIVATE composable_kernel::device_operations)
+add_executable(client_fused_attention_bias_mask_no_lib fused_attention_bias_mask_no_lib.cpp)
+target_link_libraries(client_fused_attention_bias_mask_no_lib PRIVATE composable_kernel::device_operations)
-add_executable(client_fused_attention_no_lib fused_attention_no_lib.cpp)
-target_link_libraries(client_fused_attention_no_lib PRIVATE composable_kernel::device_operations)
+add_executable(client_fused_attention_mask_no_lib fused_attention_mask_no_lib.cpp)
+target_link_libraries(client_fused_attention_mask_no_lib PRIVATE composable_kernel::device_operations)
client_example/08_fused_attention/fused_attention_bias_mask.cpp → client_example/08_fused_attention/fused_attention_bias_mask_no_lib.cpp

@@ -5,14 +5,46 @@
 #include <vector>

 #include "ck/ck.hpp"
-#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
+#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

+struct ScaleBiasMask
+{
+    ScaleBiasMask(float scale, float mask_filter_value)
+        : scale_(scale), mask_filter_value_(mask_filter_value)
+    {
+    }
+
+    // biased, masked
+    template <typename Y, typename X0, typename X1, typename X2>
+    __host__ __device__ constexpr void
+    operator()(Y& y, const X0& x, const X1& bias, const X2& mask) const;
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()(float& y, const float& x, const ck::half_t& bias, const int16_t& mask) const
+    {
+        float filter_value = (mask == 1 ? 0.0f : mask_filter_value_);
+
+        y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
+    }
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()(float& y, const float& x, const ck::half_t& bias, const ck::half_t& mask) const
+    {
+        float filter_value = (mask < 1.0f ? mask_filter_value_ : 0.0f);
+
+        y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
+    }
+
+    const float scale_;
+    const float mask_filter_value_;
+};
+
 using AElementOp    = ck::tensor_operation::element_wise::PassThrough;
 using B0ElementOp   = ck::tensor_operation::element_wise::PassThrough;
-using Acc0ElementOp = ck::tensor_operation::element_wise::ScaleBiasMask;
+using Acc0ElementOp = ScaleBiasMask;
 using B1ElementOp   = ck::tensor_operation::element_wise::PassThrough;
 using CElementOp    = ck::tensor_operation::element_wise::PassThrough;
...
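The client-side functor added above is easy to sanity-check on the host. The sketch below is not the CK functor itself; it only repeats the same formula (y = scale_ * x + bias + filter_value) with plain float standing in for ck::half_t, and the -1e9f filter value is an assumed, typical choice for pushing masked attention scores toward negative infinity before the softmax.

// Host-only sketch of the ScaleBiasMask arithmetic above; plain float stands in
// for ck::half_t, and -1e9f is an assumed (typical) mask_filter_value.
#include <cstdio>

struct ScaleBiasMaskSketch
{
    float scale_;
    float mask_filter_value_;

    // mask == 1 keeps the element; any other value adds the filter value,
    // pushing the score toward -inf so the softmax assigns it ~zero weight.
    float operator()(float x, float bias, int mask) const
    {
        const float filter_value = (mask == 1 ? 0.0f : mask_filter_value_);
        return scale_ * x + bias + filter_value;
    }
};

int main()
{
    const ScaleBiasMaskSketch op{0.125f, -1e9f};
    std::printf("kept element:   %f\n", op(2.0f, 0.5f, 1)); // 0.125 * 2.0 + 0.5 = 0.75
    std::printf("masked element: %f\n", op(2.0f, 0.5f, 0)); // 0.75 - 1e9
    return 0;
}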
client_example/08_fused_attention/fused_attention_no_lib.cpp → client_example/08_fused_attention/fused_attention_mask_no_lib.cpp

@@ -5,7 +5,7 @@
 #include <vector>

 #include "ck/ck.hpp"
-#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
+#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp

@@ -389,37 +389,6 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
     }
 };

-struct ScaleBiasMask
-{
-    ScaleBiasMask(float scale, float mask_filter_value)
-        : scale_(scale), mask_filter_value_(mask_filter_value)
-    {
-    }
-
-    // biased, masked
-    template <typename Y, typename X0, typename X1, typename X2>
-    __host__ __device__ constexpr void
-    operator()(Y& y, const X0& x, const X1& bias, const X2& mask) const;
-
-    template <>
-    __host__ __device__ constexpr void
-    operator()(float& y, const float& x, const half_t& bias, const int16_t& mask) const
-    {
-        float filter_value = (mask == 1 ? 0.0f : mask_filter_value_);
-
-        y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
-    }
-
-    template <>
-    __host__ __device__ constexpr void
-    operator()(float& y, const float& x, const half_t& bias, const half_t& mask) const
-    {
-        float filter_value = (mask < 1.0f ? mask_filter_value_ : 0.0f);
-
-        y = scale_ * x + ck::type_convert<float>(bias) + filter_value;
-    }
-
-    const float scale_;
-    const float mask_filter_value_;
-};
-
 } // namespace element_wise
 } // namespace tensor_operation
 } // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp
deleted 100644 → 0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>>>& instances);

void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances);

void add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>>>& instances);

void add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances);

template <typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasDataType,
          MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, MaskingSpec>>
{
    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough, MaskingSpec>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t> &&
                     Acc0BiasDataType::Size() == 1 &&
                     is_same_v<tuple_element_t<0, Acc0BiasDataType>, half_t>)
        {
            if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
            {
                add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(op_ptrs);
            }
            else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
            {
                add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
                          is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16> &&
                          Acc0BiasDataType::Size() == 1 &&
                          is_same_v<tuple_element_t<0, Acc0BiasDataType>, BF16>)
        {
            if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
            {
                add_device_batched_gemm_bias_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(op_ptrs);
            }
            else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
            {
                add_device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
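For context, the factory specialization deleted here was typically consumed through DeviceOperationInstanceFactory<DeviceOp>::GetInstances(). The sketch below reconstructs that usage from the declarations above; it assumes the pre-commit CK tree (this header still present), so treat the include path and the exact template-argument order as illustrative rather than authoritative.

// Sketch only: querying the bias-specific factory that this commit removes.
// Assumes the pre-commit headers shown above; F16, ScaleAdd, PassThrough and
// MaskingSpecialization come from the CK element_wise/device namespaces.
#include <iostream>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_bias_softmax_gemm_permute.hpp"

int main()
{
    using F16         = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;
    using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute;
    using ck::tensor_operation::device::MaskingSpecialization;

    // The shape the deleted specialization matched on: two G (batch/head) dims,
    // one dim each for M/N/K/O, f16 data, one f16 bias tensor, ScaleAdd epilogue,
    // upper-triangle (causal) masking.
    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1,
        F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances\n";
    return 0;
}

With this header and the library-side ScaleBiasMask gone, the bias + mask path lives in the client example (fused_attention_bias_mask_no_lib.cpp) instead of in the instance library.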
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp

@@ -11,13 +11,99 @@
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 #include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
+#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp"

 namespace ck {
 namespace tensor_operation {
 namespace device {
 namespace instance {

-void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
+void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
+    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
+        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
+        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
+        MaskingSpecialization::MaskOutUpperTriangle>>>& instances);
+
+void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
+    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
+        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
+        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
+        MaskingSpecialization::MaskDisabled>>>& instances);
+
+void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
+    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
+        2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>,
+        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
+        MaskingSpecialization::MaskOutUpperTriangle>>>& instances);
+
+void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
+    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
+        2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>,
+        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
+        MaskingSpecialization::MaskDisabled>>>& instances);
+
+void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
     std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
                                                                     1,
                                                                     1,
...

@@ -38,7 +124,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f
                                                                     MaskingSpecialization::MaskOutUpperTriangle>>>&
         instances);

-void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(
+void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
     std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
                                                                     1,
...

@@ -59,7 +145,7 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
                                                                     MaskingSpecialization::MaskDisabled>>>&
         instances);

-void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
+void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
     std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
                                                                     1,
...

@@ -80,7 +166,7 @@ void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16
                                                                     MaskingSpecialization::MaskOutUpperTriangle>>>&
         instances);

-void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
+void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
     std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<2,
                                                                     1,
...

@@ -100,85 +186,67 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf
                                                                     PassThrough,
                                                                     PassThrough,
                                                                     MaskingSpecialization::MaskDisabled>>>&
         instances);

-template <typename ADataType,
-          typename B0DataType,
-          typename B1DataType,
-          typename CDataType,
-          MaskingSpecialization MaskingSpec>
-struct DeviceOperationInstanceFactory<
-    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<
-        2, 1, 1, 1, 1, ADataType, B0DataType, B1DataType, CDataType, ck::Tuple<>, ck::Tuple<>,
-        PassThrough, PassThrough, Scale, PassThrough, PassThrough, MaskingSpec>>
+template <index_t NumDimG,
+          index_t NumDimM,
+          index_t NumDimN,
+          index_t NumDimK,
+          index_t NumDimO,
+          typename ADataType,
+          typename B0DataType,
+          typename B1DataType,
+          typename CDataType,
+          typename Acc0BiasDataType,
+          typename Acc1BiasDataType,
+          typename AElementwiseOperation,
+          typename B0ElementwiseOperation,
+          typename C0DEElementwiseOperation,
+          typename B1ElementwiseOperation,
+          typename C1DEElementwiseOperation,
+          MaskingSpecialization MaskingSpec>
+struct DeviceOperationInstanceFactory<
+    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<
+        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
+        ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc1BiasDataType,
+        AElementwiseOperation, B0ElementwiseOperation, C0DEElementwiseOperation,
+        B1ElementwiseOperation, C1DEElementwiseOperation, MaskingSpec>>
 {
-    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<
-        2, 1, 1, 1, 1, ADataType, B0DataType, B1DataType, CDataType, ck::Tuple<>, ck::Tuple<>,
-        PassThrough, PassThrough, Scale, PassThrough, PassThrough, MaskingSpec>;
+    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<
+        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
+        ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc1BiasDataType,
+        AElementwiseOperation, B0ElementwiseOperation, C0DEElementwiseOperation,
+        B1ElementwiseOperation, C1DEElementwiseOperation, MaskingSpec>;

     static auto GetInstances()
     {
         std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

-        if constexpr(is_same_v<ADataType, half_t> && is_same_v<B0DataType, half_t> &&
-                     is_same_v<B1DataType, half_t> && is_same_v<CDataType, half_t>)
-        {
-            if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
-            {
-                add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(op_ptrs);
-            }
-            else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
-            {
-                add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instances(op_ptrs);
-            }
-        }
-        else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
-                          is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
-        {
-            if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
-            {
-                add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(op_ptrs);
-            }
-            else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
-            {
-                add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(op_ptrs);
-            }
-        }
+        add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(op_ptrs);

         return op_ptrs;
     }
 };

 } // namespace instance
 } // namespace device
 } // namespace tensor_operation
...
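A query against the generalized factory in this header might look like the sketch below. It assumes the post-commit CK tree, and only a combination for which a matching add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances overload is declared above (for example f16 data with a ck::Tuple<F16> bias and ScaleAdd) will actually link; everything here is illustrative, not taken verbatim from the repository.

// Sketch only: the generalized factory after this commit. The bias-enabled
// query goes through the same header as the plain Scale one and resolves to
// whichever add_device_batched_gemm_mutiple_d_..._instances overload matches
// the requested DeviceOp type.
#include <iostream>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"

int main()
{
    using F16         = ck::half_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using ScaleAdd    = ck::tensor_operation::element_wise::ScaleAdd;
    using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute;
    using ck::tensor_operation::device::MaskingSpecialization;

    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1,                 // NumDimG, NumDimM, NumDimN, NumDimK, NumDimO
        F16, F16, F16, F16,            // A, B0, B1, C data types
        ck::Tuple<F16>, ck::Tuple<>,   // Acc0 bias tuple, Acc1 bias tuple
        PassThrough, PassThrough,      // A, B0 elementwise ops
        ScaleAdd,                      // C0DE (attention-score) elementwise op
        PassThrough, PassThrough,      // B1, C1DE elementwise ops
        MaskingSpecialization::MaskDisabled>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances\n";
    return 0;
}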
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp
deleted 100644 → 0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16, F16>, ck::Tuple<>,
        PassThrough, PassThrough, element_wise::ScaleBiasMask, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>>>& instances);

void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16, F16>, ck::Tuple<>,
        PassThrough, PassThrough, element_wise::ScaleBiasMask, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances);

void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>>>& instances);

void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances);

void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>>>& instances);

void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<BF16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleAdd, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances);

void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>>>& instances);

void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances);

void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>>>& instances);

void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>,
        PassThrough, PassThrough, Scale, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances);

template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          index_t NumDimO,
          typename ADataType,
          typename B0DataType,
          typename B1DataType,
          typename CDataType,
          typename Acc0BiasDataType,
          typename Acc1BiasDataType,
          typename AElementwiseOperation,
          typename B0ElementwiseOperation,
          typename C0DEElementwiseOperation,
          typename B1ElementwiseOperation,
          typename C1DEElementwiseOperation,
          MaskingSpecialization MaskingSpec>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc1BiasDataType,
        AElementwiseOperation, B0ElementwiseOperation, C0DEElementwiseOperation,
        B1ElementwiseOperation, C1DEElementwiseOperation, MaskingSpec>>
{
    using DeviceOp = DeviceBatchedGemmSoftmaxGemmPermute<
        NumDimG, NumDimM, NumDimN, NumDimK, NumDimO,
        ADataType, B0DataType, B1DataType, CDataType, Acc0BiasDataType, Acc1BiasDataType,
        AElementwiseOperation, B0ElementwiseOperation, C0DEElementwiseOperation,
        B1ElementwiseOperation, C1DEElementwiseOperation, MaskingSpec>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(op_ptrs);

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt

@@ -3,6 +3,5 @@ add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
 device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
 device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
 device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
-device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp
 )
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instance.cpp
deleted 100644 → 0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_multiple_d_softmax_gemm_permute_xdl_cshuffle_half_gmk_gnk_gno_gmo_instance.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using PassThrough   = ck::tensor_operation::element_wise::PassThrough;
using ScaleBiasMask = ck::tensor_operation::element_wise::ScaleBiasMask;

// f16 ScaleBiasMask masking
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16, F16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleBiasMask, PassThrough, PassThrough,
        MaskingSpecialization::MaskOutUpperTriangle>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
            2, 1, 1, 1, 1, F16, F32, ck::Tuple<F16, F16>, ScaleBiasMask,
            MaskingSpecialization::MaskOutUpperTriangle>{});
}

// f16 ScaleBiasMask disable masking
void add_device_batched_gemm_mutiple_d_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
        2, 1, 1, 1, 1, F16, F16, F16, F16, ck::Tuple<F16, F16>, ck::Tuple<>,
        PassThrough, PassThrough, ScaleBiasMask, PassThrough, PassThrough,
        MaskingSpecialization::MaskDisabled>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_gemm_bias_softmax_gemm_permute_xdl_cshuffle_gmk_gnk_gno_gmo_instances<
            2, 1, 1, 1, 1, F16, F32, ck::Tuple<F16, F16>, ScaleBiasMask,
            MaskingSpecialization::MaskDisabled>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
profiler/include/profiler/profile_batched_gemm_bias_softmax_gemm_permute_impl.hpp

@@ -10,7 +10,7 @@
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
+#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
...
profiler/include/profiler/profile_batched_gemm_softmax_gemm_permute_impl.hpp

@@ -10,7 +10,7 @@
 #include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm_permute.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute_general.hpp"
+#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"
 #include "ck/library/utility/check_err.hpp"
 #include "ck/library/utility/device_memory.hpp"
...