Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
e7be2fe8
Unverified
Commit
e7be2fe8
authored
Feb 10, 2023
by
pmaybank
Committed by
GitHub
Feb 10, 2023
Browse files
Merge branch 'develop' into sphinx_doc
parents
f68fa79a
f7d28f3e
Changes
347
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
779 additions
and
1 deletion
+779
-1
library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp
...nsor_operation_instance/gpu/elementwise_normalization.hpp
+79
-0
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp
...y/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp
+0
-1
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp
...brary/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp
+145
-0
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp
...brary/tensor_operation_instance/gpu/gemm_add_multiply.hpp
+155
-0
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp
...or_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp
+172
-0
library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp
...k/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp
+138
-0
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
...ration_instance/gpu/grouped_convolution_backward_data.hpp
+90
-0
No files found.
Too many changes to show.
To preserve performance only
347 of 347+
files are displayed.
Plain diff
Email patch
library/include/ck/library/tensor_operation_instance/gpu/elementwise_normalization.hpp
0 → 100644
View file @
e7be2fe8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// FP16
void
add_device_elementwise_normalization_rank_2_1_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceElementwiseNormalization
<
ck
::
Tuple
<
F16
,
F16
>
,
F16
,
F16
,
F32
,
F16
,
element_wise
::
Add
,
PassThrough
,
2
,
1
>>>&
);
template
<
typename
InDataTypeTuple
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
YDataType
,
index_t
Rank
,
index_t
NumReduceDim
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceElementwiseNormalization
<
InDataTypeTuple
,
GammaDataType
,
BetaDataType
,
F32
,
YDataType
,
ck
::
tensor_operation
::
element_wise
::
Add
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Rank
,
NumReduceDim
>>
{
using
DeviceOp
=
DeviceElementwiseNormalization
<
InDataTypeTuple
,
GammaDataType
,
BetaDataType
,
F32
,
YDataType
,
ck
::
tensor_operation
::
element_wise
::
Add
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
Rank
,
NumReduceDim
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
GammaDataType
,
F16
>
&&
is_same_v
<
BetaDataType
,
F16
>
&&
is_same_v
<
YDataType
,
F16
>
)
{
if
constexpr
(
Rank
==
2
&&
NumReduceDim
==
1
)
{
add_device_elementwise_normalization_rank_2_1_f16_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_add_fastgelu.hpp
View file @
e7be2fe8
...
...
@@ -10,7 +10,6 @@
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
...
...
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp
0 → 100644
View file @
e7be2fe8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Row
,
Row_Tuple
,
Row
,
F16
,
F16
,
F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddFastGelu
>>>&
);
void
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Row_Tuple
,
Row
,
F16
,
F16
,
F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddFastGelu
>>>&
);
void
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
Row
,
Row_Tuple
,
Row
,
F16
,
F16
,
F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddFastGelu
>>>&
);
void
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
Col
,
Row_Tuple
,
Row
,
F16
,
F16
,
F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddFastGelu
>>>&
);
// GEMM + Add + FastGelu
template
<
typename
ALayout
,
typename
BLayout
,
typename
D0Layout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
D0DataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
ck
::
Tuple
<
D0Layout
>
,
ELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
D0DataType
>
,
EDataType
,
PassThrough
,
PassThrough
,
AddFastGelu
>>
{
using
DeviceOp
=
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
ck
::
Tuple
<
D0Layout
>
,
ELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
D0DataType
>
,
EDataType
,
PassThrough
,
PassThrough
,
AddFastGelu
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
D0DataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_multiply.hpp
0 → 100644
View file @
e7be2fe8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Row
,
Row_Row_Tuple
,
Row
,
F16
,
F16
,
F16_F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddMultiply
>>>&
);
void
add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Row_Row_Tuple
,
Row
,
F16
,
F16
,
F16_F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddMultiply
>>>&
);
void
add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
Row
,
Row_Row_Tuple
,
Row
,
F16
,
F16
,
F16_F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddMultiply
>>>&
);
void
add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
Col
,
Row_Row_Tuple
,
Row
,
F16
,
F16
,
F16_F16_Tuple
,
F16
,
PassThrough
,
PassThrough
,
AddMultiply
>>>&
);
// GEMM + Add + Multiply
template
<
typename
ALayout
,
typename
BLayout
,
typename
D0Layout
,
typename
D1Layout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
D0DataType
,
typename
D1DataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
ck
::
Tuple
<
D0Layout
,
D1Layout
>
,
ELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
D0DataType
,
D1DataType
>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
AddMultiply
>>
{
using
DeviceOp
=
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
ck
::
Tuple
<
D0Layout
,
D1Layout
>
,
ELayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
D0DataType
,
D1DataType
>
,
EDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
AddMultiply
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
D0DataType
,
half_t
>
&&
is_same_v
<
D1DataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_kn_mn_mn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_mk_nk_mn_mn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_kn_mn_mn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_add_multiply_xdl_c_shuffle_f16_f16_f16_f16_f16_km_nk_mn_mn_mn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_relu_add_layernorm.hpp
0 → 100644
View file @
e7be2fe8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDLayernorm
<
Row
,
Row
,
Row_Row_Tuple
,
Row
,
F16
,
F16
,
F16_F16_Tuple
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
AddReluAdd
,
PassThrough
>>>&
);
void
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDLayernorm
<
Row
,
Col
,
Row_Row_Tuple
,
Row
,
F16
,
F16
,
F16_F16_Tuple
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
AddReluAdd
,
PassThrough
>>>&
);
void
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDLayernorm
<
Col
,
Row
,
Row_Row_Tuple
,
Row
,
F16
,
F16
,
F16_F16_Tuple
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
AddReluAdd
,
PassThrough
>>>&
);
void
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleDLayernorm
<
Col
,
Col
,
Row_Row_Tuple
,
Row
,
F16
,
F16
,
F16_F16_Tuple
,
F16
,
F16
,
F16
,
PassThrough
,
PassThrough
,
AddReluAdd
,
PassThrough
>>>&
);
// GEMM + Add + Relu + Add + Layernorm
template
<
typename
ALayout
,
typename
BLayout
,
typename
D0Layout
,
typename
D1Layout
,
typename
HLayout
,
typename
ADataType
,
typename
BDataType
,
typename
D0DataType
,
typename
D1DataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
HDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleDLayernorm
<
ALayout
,
BLayout
,
ck
::
Tuple
<
D0Layout
,
D1Layout
>
,
HLayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
D0DataType
,
D1DataType
>
,
GammaDataType
,
BetaDataType
,
HDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
AddReluAdd
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
{
using
DeviceOp
=
DeviceGemmMultipleDLayernorm
<
ALayout
,
BLayout
,
ck
::
Tuple
<
D0Layout
,
D1Layout
>
,
HLayout
,
ADataType
,
BDataType
,
ck
::
Tuple
<
D0DataType
,
D1DataType
>
,
GammaDataType
,
BetaDataType
,
HDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
AddReluAdd
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
D0DataType
,
half_t
>
&&
is_same_v
<
D1DataType
,
half_t
>
&&
is_same_v
<
GammaDataType
,
half_t
>
&&
is_same_v
<
BetaDataType
,
half_t
>
&&
is_same_v
<
HDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
HLayout
,
Row
>
)
{
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_kn_mn_mn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
HLayout
,
Row
>
)
{
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_mk_nk_mn_mn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
HLayout
,
Row
>
)
{
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_kn_mn_mn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
D0Layout
,
Row
>
&&
is_same_v
<
D1Layout
,
Row
>
&&
is_same_v
<
HLayout
,
Row
>
)
{
add_device_gemm_add_relu_add_xdl_c_shuffle_layernorm_f16_km_nk_mn_mn_mn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp
0 → 100644
View file @
e7be2fe8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include <vector>
#include <memory>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
void
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
>>>&
);
void
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Row
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
>>>&
);
void
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
Row
,
Empty_Tuple
,
Row
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
>>>&
);
void
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGemmMultipleD
<
Col
,
Col
,
Empty_Tuple
,
Row
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
FastGelu
>>>&
);
// GEMM + FastGelu
template
<
typename
ALayout
,
typename
BLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
EDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
Empty_Tuple
,
ELayout
,
ADataType
,
BDataType
,
Empty_Tuple
,
EDataType
,
PassThrough
,
PassThrough
,
FastGelu
>>
{
using
DeviceOp
=
DeviceGemmMultipleD
<
ALayout
,
BLayout
,
Empty_Tuple
,
ELayout
,
ADataType
,
BDataType
,
Empty_Tuple
,
EDataType
,
PassThrough
,
PassThrough
,
FastGelu
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
ADataType
,
half_t
>
&&
is_same_v
<
BDataType
,
half_t
>
&&
is_same_v
<
EDataType
,
half_t
>
)
{
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Row
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances
(
op_ptrs
);
}
else
if
constexpr
(
is_same_v
<
ALayout
,
Col
>
&&
is_same_v
<
BLayout
,
Col
>
&&
is_same_v
<
ELayout
,
Row
>
)
{
add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/tensor_operation_instance/gpu/grouped_convolution_backward_data.hpp
0 → 100644
View file @
e7be2fe8
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// conv2d backward data
void
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceGroupedConvBwdDataMultipleD
<
2
,
GNHWK
,
GKYXC
,
Empty_Tuple
,
GNHWC
,
F16
,
F16
,
Empty_Tuple
,
F16
,
PassThrough
,
PassThrough
,
PassThrough
>>>&
instances
);
template
<
ck
::
index_t
NumDimSpatial
,
typename
OutLayout
,
typename
WeiLayout
,
typename
InLayout
,
typename
OutDataType
,
typename
WeiDataType
,
typename
InDataType
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceGroupedConvBwdDataMultipleD
<
NumDimSpatial
,
OutLayout
,
WeiLayout
,
Empty_Tuple
,
InLayout
,
OutDataType
,
WeiDataType
,
Empty_Tuple
,
InDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>>
{
using
DeviceOp
=
DeviceGroupedConvBwdDataMultipleD
<
NumDimSpatial
,
OutLayout
,
WeiLayout
,
Empty_Tuple
,
InLayout
,
OutDataType
,
WeiDataType
,
Empty_Tuple
,
InDataType
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
NumDimSpatial
==
2
&&
is_same_v
<
InLayout
,
GNHWC
>
&&
is_same_v
<
WeiLayout
,
GKYXC
>
&&
is_same_v
<
OutLayout
,
GNHWK
>
)
{
if
constexpr
(
is_same_v
<
InDataType
,
F16
>
&&
is_same_v
<
WeiDataType
,
F16
>
&&
is_same_v
<
OutDataType
,
F16
>
)
{
add_device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
Prev
1
…
14
15
16
17
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment