Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
d1006d46
Commit
d1006d46
authored
Aug 30, 2023
by
Jing Zhang
Browse files
removed AddBias
parent
c4411860
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
24 additions
and
51 deletions
+24
-51
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
...e/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
+2
-2
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
...r_operation/gpu/element/binary_element_wise_operation.hpp
+7
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+0
-34
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
+5
-4
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
..._operation_instance/device_operation_instance_factory.hpp
+0
-1
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp
...brary/tensor_operation_instance/gpu/grouped_gemm_bias.hpp
+6
-6
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
..._gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
..._gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp
..._gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp
+1
-1
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp
..._gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp
+1
-1
No files found.
example/15_grouped_gemm/grouped_gemm_xdl_fixed_nk_bias_fp16.cpp
View file @
d1006d46
...
@@ -30,7 +30,7 @@ using Row = ck::tensor_layout::gemm::RowMajor;
...
@@ -30,7 +30,7 @@ using Row = ck::tensor_layout::gemm::RowMajor;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
Bias
=
ck
::
tensor_operation
::
element_wise
::
Add
Bias
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
using
ADataType
=
F16
;
using
ADataType
=
F16
;
using
BDataType
=
F16
;
using
BDataType
=
F16
;
...
@@ -49,7 +49,7 @@ using ELayout = Row;
...
@@ -49,7 +49,7 @@ using ELayout = Row;
using
AElementOp
=
PassThrough
;
using
AElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
BElementOp
=
PassThrough
;
using
CDEElementOp
=
Add
Bias
;
using
CDEElementOp
=
Add
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
;
static
constexpr
auto
GemmDefault
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
;
...
...
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp
View file @
d1006d46
...
@@ -36,6 +36,13 @@ struct Add
...
@@ -36,6 +36,13 @@ struct Add
y
=
x0
+
type_convert
<
half_t
>
(
x1
);
y
=
x0
+
type_convert
<
half_t
>
(
x1
);
};
};
template
<
>
__host__
__device__
constexpr
void
operator
()
<
half_t
>
(
half_t
&
y
,
const
float
&
x0
,
const
float
&
x1
)
const
{
y
=
type_convert
<
half_t
>
(
x0
+
x1
);
};
template
<
>
template
<
>
__host__
__device__
constexpr
void
__host__
__device__
constexpr
void
operator
()
<
half_t
>
(
half_t
&
y
,
const
float
&
x0
,
const
half_t
&
x1
)
const
operator
()
<
half_t
>
(
half_t
&
y
,
const
float
&
x0
,
const
half_t
&
x1
)
const
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
d1006d46
...
@@ -57,12 +57,6 @@ struct PassThrough
...
@@ -57,12 +57,6 @@ struct PassThrough
y
=
x
;
y
=
x
;
}
}
template
<
>
__host__
__device__
void
operator
()
<
half_t
,
float
>
(
half_t
&
y
,
const
float
&
x
)
const
{
y
=
type_convert
<
half_t
>
(
x
);
}
template
<
>
template
<
>
__host__
__device__
void
operator
()
<
bhalf_t
,
float
>
(
bhalf_t
&
y
,
const
float
&
x
)
const
__host__
__device__
void
operator
()
<
bhalf_t
,
float
>
(
bhalf_t
&
y
,
const
float
&
x
)
const
{
{
...
@@ -126,34 +120,6 @@ struct PassThrough
...
@@ -126,34 +120,6 @@ struct PassThrough
}
}
};
};
struct
AddBias
{
template
<
typename
E
,
typename
C
,
typename
D0
>
__host__
__device__
void
operator
()(
E
&
e
,
const
C
&
c
,
const
D0
&
d0
)
const
;
template
<
>
__host__
__device__
void
operator
()
<
ck
::
half_t
,
float
,
float
>
(
ck
::
half_t
&
e
,
const
float
&
c
,
const
float
&
d0
)
const
{
e
=
c
+
d0
;
}
template
<
>
__host__
__device__
void
operator
()
<
ck
::
half_t
,
ck
::
half_t
,
float
>
(
ck
::
half_t
&
e
,
const
ck
::
half_t
&
c
,
const
float
&
d0
)
const
{
e
=
c
+
d0
;
}
template
<
>
__host__
__device__
void
operator
()
<
float
,
float
,
float
>
(
float
&
e
,
const
float
&
c
,
const
float
&
d0
)
const
{
e
=
c
+
d0
;
}
};
struct
UnaryConvert
struct
UnaryConvert
{
{
template
<
typename
Y
,
typename
X
>
template
<
typename
Y
,
typename
X
>
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
View file @
d1006d46
...
@@ -876,8 +876,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
...
@@ -876,8 +876,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const
index_t
StrideE
,
const
index_t
StrideE
,
const
Block2ETileMap
&
block_2_etile_map
)
const
Block2ETileMap
&
block_2_etile_map
)
{
{
const
auto
p_a_grid
=
reinterpret_cast
<
const
A
B
DataType
*>
(
p_a_grid_
);
const
auto
p_a_grid
=
reinterpret_cast
<
const
ADataType
*>
(
p_a_grid_
);
const
auto
p_b_grid
=
reinterpret_cast
<
const
A
BDataType
*>
(
p_b_grid_
);
const
auto
p_b_grid
=
reinterpret_cast
<
const
BDataType
*>
(
p_b_grid_
);
const
auto
p_e_grid
=
reinterpret_cast
<
EDataType
*>
(
p_e_grid_
);
const
auto
p_e_grid
=
reinterpret_cast
<
EDataType
*>
(
p_e_grid_
);
// tensor descriptors for problem definiton
// tensor descriptors for problem definiton
...
@@ -902,8 +902,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
...
@@ -902,8 +902,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
const
auto
b_grid_desc_bk0_n_bk1
=
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k
);
const
auto
b_grid_desc_bk0_n_bk1
=
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k
);
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
...
...
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp
View file @
d1006d46
...
@@ -101,7 +101,6 @@ using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
...
@@ -101,7 +101,6 @@ using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using
ScaleAdd
=
ck
::
tensor_operation
::
element_wise
::
ScaleAdd
;
using
ScaleAdd
=
ck
::
tensor_operation
::
element_wise
::
ScaleAdd
;
using
Gelu
=
ck
::
tensor_operation
::
element_wise
::
Gelu
;
using
Gelu
=
ck
::
tensor_operation
::
element_wise
::
Gelu
;
using
Swish
=
ck
::
tensor_operation
::
element_wise
::
Swish
;
using
Swish
=
ck
::
tensor_operation
::
element_wise
::
Swish
;
using
AddBias
=
ck
::
tensor_operation
::
element_wise
::
AddBias
;
template
<
typename
Activation
>
template
<
typename
Activation
>
using
Activation_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul_Clamp
<
Activation
>
;
using
Activation_Mul_Clamp
=
ck
::
tensor_operation
::
element_wise
::
Activation_Mul_Clamp
<
Activation
>
;
...
...
library/include/ck/library/tensor_operation_instance/gpu/grouped_gemm_bias.hpp
View file @
d1006d46
...
@@ -28,7 +28,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
...
@@ -28,7 +28,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instances(
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Add
Bias
>>>&
instances
);
Add
>>>&
instances
);
void
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances
(
void
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
...
@@ -41,7 +41,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
...
@@ -41,7 +41,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instances(
F16
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Add
Bias
>>>&
instances
);
Add
>>>&
instances
);
// fp32_output
// fp32_output
void
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances
(
void
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances
(
...
@@ -55,7 +55,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(
...
@@ -55,7 +55,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instances(
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Add
Bias
>>>&
instances
);
Add
>>>&
instances
);
void
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances
(
void
add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedGemmFixedNK
<
Row
,
...
@@ -68,7 +68,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
...
@@ -68,7 +68,7 @@ void add_device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instances(
F32
,
F32
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Add
Bias
>>>&
instances
);
Add
>>>&
instances
);
template
<
typename
ALayout
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
...
@@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -87,7 +87,7 @@ struct DeviceOperationInstanceFactory<
EDataType
,
EDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Add
Bias
>>
Add
>>
{
{
using
DeviceOp
=
DeviceGroupedGemmFixedNK
<
ALayout
,
using
DeviceOp
=
DeviceGroupedGemmFixedNK
<
ALayout
,
BLayout
,
BLayout
,
...
@@ -99,7 +99,7 @@ struct DeviceOperationInstanceFactory<
...
@@ -99,7 +99,7 @@ struct DeviceOperationInstanceFactory<
EDataType
,
EDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
PassThrough
,
Add
Bias
>
;
Add
>
;
static
auto
GetInstances
()
static
auto
GetInstances
()
{
{
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_kn_mn_instance.cpp
View file @
d1006d46
...
@@ -31,7 +31,7 @@ using D0Layout = Row;
...
@@ -31,7 +31,7 @@ using D0Layout = Row;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
>
;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
Bias
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f16_mk_nk_mn_instance.cpp
View file @
d1006d46
...
@@ -31,7 +31,7 @@ using D0Layout = Row;
...
@@ -31,7 +31,7 @@ using D0Layout = Row;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
>
;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
Bias
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_kn_mn_instance.cpp
View file @
d1006d46
...
@@ -31,7 +31,7 @@ using D0Layout = Row;
...
@@ -31,7 +31,7 @@ using D0Layout = Row;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
>
;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
Bias
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
...
...
library/src/tensor_operation_instance/gpu/grouped_gemm_bias/device_grouped_gemm_xdl_fixed_nk_bias_f16_f16_f32_mk_nk_mn_instance.cpp
View file @
d1006d46
...
@@ -31,7 +31,7 @@ using D0Layout = Row;
...
@@ -31,7 +31,7 @@ using D0Layout = Row;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
>
;
using
DsLayout
=
ck
::
Tuple
<
D0Layout
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
Bias
;
using
Add
=
ck
::
tensor_operation
::
element_wise
::
Add
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
static
constexpr
auto
GemmMNKPadding
=
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment