gaoqiong / composable_kernel / Commits / cd0c1f57

Unverified commit cd0c1f57, authored Apr 19, 2023 by turneram; committed by GitHub, Apr 19, 2023.

Merge branch 'develop' into migx-device-interface

Parents: c72a0d3e, bb0b772d

Showing 20 of 122 changed files, with 401 additions and 154 deletions (+401 -154).
example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc    +1   -2
example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc       +1   -2
example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc         +1   -2
example/42_groupnorm/CMakeLists.txt                                                         +2   -1
example/42_groupnorm/common.hpp                                                             +23  -0
example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp                                         +56  -0
example/42_groupnorm/groupnorm_swish_fp16.cpp                                               +40  -0
example/42_groupnorm/run_groupnorm_example.inc                                              +7   -72
include/ck/ck.hpp                                                                           +8   -1
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp               +1   -1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp  +1  -10
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp   +1  -1
include/ck/tensor_operation/gpu/element/quantization_operation.hpp                          +125 -6
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp                    +29  -1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp             +19  -2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp              +19  -7
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp                          +1   -1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp                                 +7   -0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp                    +32  -28
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp                          +27  -17
example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_relu_perlayer_quantization_example.inc
→ example/40_conv2d_fwd_quantization/run_conv2d_fwd_bias_perlayer_quantization_example.inc

@@ -155,7 +155,7 @@ bool run_grouped_conv_fwd(bool do_verification,
     return (pass ? 0 : 1);
 }

-int run_conv2d_fwd_bias_relu_perlayer_quantization_example()
+int run_conv2d_fwd_bias_perlayer_quantization_example(const OutElementOp& out_element_op)
 {
     bool do_verification = true;
     bool time_kernel     = true;

@@ -177,7 +177,6 @@ int run_conv2d_fwd_bias_relu_perlayer_quantization_example()
     const auto in_element_op  = InElementOp{};
     const auto wei_element_op = WeiElementOp{};
-    const auto out_element_op = OutElementOp{0.5f, ActivationOp{}};

     using InLayout  = ck::tensor_layout::convolution::GNHWC;
     using WeiLayout = ck::tensor_layout::convolution::GKYXC;
example/40_conv2d_fwd_quantization/run_conv2d_fwd_perchannel_quantization_example.inc

@@ -157,7 +157,7 @@ bool run_grouped_conv_fwd(bool do_verification,
     return (pass ? 0 : 1);
 }

-int run_conv2d_fwd_perchannel_quantization_example()
+int run_conv2d_fwd_perchannel_quantization_example(const OutElementOp& out_element_op)
 {
     bool do_verification = true;
     bool time_kernel     = true;

@@ -179,7 +179,6 @@ int run_conv2d_fwd_perchannel_quantization_example()
     const auto in_element_op  = InElementOp{};
     const auto wei_element_op = WeiElementOp{};
-    const auto out_element_op = OutElementOp{ActivationOp{}};

     using InLayout  = ck::tensor_layout::convolution::GNHWC;
     using WeiLayout = ck::tensor_layout::convolution::GKYXC;
example/40_conv2d_fwd_quantization/run_conv2d_fwd_perlayer_quantization_example.inc

@@ -139,7 +139,7 @@ bool run_grouped_conv_fwd(bool do_verification,
     return (pass ? 0 : 1);
 }

-int run_conv2d_fwd_perlayer_quantization_example()
+int run_conv2d_fwd_perlayer_quantization_example(const OutElementOp& out_element_op)
 {
     bool do_verification = true;
     bool time_kernel     = false;

@@ -161,7 +161,6 @@ int run_conv2d_fwd_perlayer_quantization_example()
     const auto in_element_op  = InElementOp{};
     const auto wei_element_op = WeiElementOp{};
-    const auto out_element_op = OutElementOp{0.5f, ActivationOp{}};

     using InLayout  = ck::tensor_layout::convolution::GNHWC;
     using WeiLayout = ck::tensor_layout::convolution::GKYXC;
example/42_groupnorm/CMakeLists.txt

-add_example_executable(example_groupnorm_sigmoid_fp16 groupnorm_sigmoid_fp16.cpp)
+add_example_executable(example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp)
+add_example_executable(example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp)
example/42_groupnorm/common.hpp (new file, +23)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <getopt.h>

#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"

#include "ck/library/utility/fill.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp"
example/42_groupnorm/groupnorm_sigmoid_mul_fp16.cpp (new file, +56)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

constexpr int Rank         = 5;
constexpr int NumReduceDim = 3;

using XDataType       = ck::half_t;
using GammaDataType   = ck::half_t;
using BetaDataType    = ck::half_t;
using YDataType       = ck::half_t;
using ComputeDataType = float;

struct YElementOp
{
    template <typename T>
    __host__ __device__ void operator()(T& y, const T& x) const
    {
        static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value ||
                          ck::is_same<T, ck::half_t>::value,
                      "Data type is not supported by this operation!");

        T a;

        ck::tensor_operation::element_wise::Sigmoid{}(a, x);

        y = x * a;
    };
};

using DeviceInstance =
    ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
                                                          GammaDataType,
                                                          BetaDataType,
                                                          ComputeDataType,
                                                          YDataType,
                                                          YElementOp,
                                                          Rank,
                                                          NumReduceDim,
                                                          1024, // BlockSize
                                                          1,    // ClusterM
                                                          1024, // ClusterK
                                                          1,    // SliceM
                                                          32,   // SliceK
                                                          1,    // SrcVecDim (0=M, 1=K)
                                                          2,    // SrcScalarPerVector
                                                          1,    // GammaVecDim (0=M, 1=K)
                                                          2,    // GammaScalarPerVector
                                                          1,    // BetaVecDim (0=M, 1=K)
                                                          2,    // BetaScalarPerVector
                                                          2>;   // OutScalarPerVector

#include "run_groupnorm_example.inc"

int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); }
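The YElementOp above computes y = x * sigmoid(x), which is exactly the Swish activation with beta = 1 used by the sibling groupnorm_swish_fp16.cpp example below. A standalone check of that identity (plain C++, no ck headers; sigmoid_mul and swish are illustrative mirrors of the ck element-wise ops, not the library's API):

#include <cassert>
#include <cmath>

// Mirrors element_wise::Sigmoid followed by y = x * a, as in YElementOp above.
float sigmoid_mul(float x) { return x * (1.0f / (1.0f + std::exp(-x))); }

// Mirrors the Swish op added in unary_element_wise_operation.hpp:
// y = x / (1 + exp(-beta * x)).
float swish(float x, float beta = 1.0f) { return x / (1.0f + std::exp(-beta * x)); }

int main()
{
    // For beta == 1 the two forms are algebraically identical:
    // x * 1/(1+exp(-x)) == x / (1+exp(-x)).
    for(float x = -8.0f; x <= 8.0f; x += 0.5f)
        assert(std::fabs(sigmoid_mul(x) - swish(x)) < 1e-6f);
}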
example/42_groupnorm/groupnorm_swish_fp16.cpp (new file, +40)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

constexpr int Rank         = 5;
constexpr int NumReduceDim = 3;

using XDataType       = ck::half_t;
using GammaDataType   = ck::half_t;
using BetaDataType    = ck::half_t;
using YDataType       = ck::half_t;
using ComputeDataType = float;

using YElementOp = ck::tensor_operation::element_wise::Swish;

using DeviceInstance =
    ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
                                                          GammaDataType,
                                                          BetaDataType,
                                                          ComputeDataType,
                                                          YDataType,
                                                          YElementOp,
                                                          Rank,
                                                          NumReduceDim,
                                                          1024, // BlockSize
                                                          1,    // ClusterM
                                                          1024, // ClusterK
                                                          1,    // SliceM
                                                          32,   // SliceK
                                                          1,    // SrcVecDim (0=M, 1=K)
                                                          2,    // SrcScalarPerVector
                                                          1,    // GammaVecDim (0=M, 1=K)
                                                          2,    // GammaScalarPerVector
                                                          1,    // BetaVecDim (0=M, 1=K)
                                                          2,    // BetaScalarPerVector
                                                          2>;   // OutScalarPerVector

#include "run_groupnorm_example.inc"

int main(int argc, char* argv[]) { run_groupnorm_example(argc, argv); }
example/42_groupnorm/groupnorm_sigmoid_fp16.cpp → example/42_groupnorm/run_groupnorm_example.inc

 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

-#include <iostream>
-#include <numeric>
-#include <initializer_list>
-#include <cstdlib>
-#include <getopt.h>
-
-#include "ck/ck.hpp"
-#include "ck/utility/reduction_enums.hpp"
-#include "ck/tensor_operation/gpu/device/impl/device_normalization_impl.hpp"
-#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
-
-#include "ck/library/utility/fill.hpp"
-#include "ck/library/utility/check_err.hpp"
-#include "ck/library/utility/device_memory.hpp"
-#include "ck/library/utility/host_common_util.hpp"
-#include "ck/library/utility/host_tensor.hpp"
-#include "ck/library/utility/host_tensor_generator.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp"
-
-constexpr int Rank         = 5;
-constexpr int NumReduceDim = 3;
-
-using XDataType       = ck::half_t;
-using GammaDataType   = ck::half_t;
-using BetaDataType    = ck::half_t;
-using YDataType       = ck::half_t;
-using ComputeDataType = float;
-
-struct YElementOp
-{
-    template <typename T>
-    __host__ __device__ void operator()(T& y, const T& x) const
-    {
-        static_assert(ck::is_same<T, float>::value || ck::is_same<T, double>::value ||
-                          ck::is_same<T, ck::half_t>::value,
-                      "Data type is not supported by this operation!");
-
-        T a;
-
-        ck::tensor_operation::element_wise::Sigmoid{}(a, x);
-
-        y = x * a;
-    };
-};
-
-using DeviceInstance =
-    ck::tensor_operation::device::DeviceNormalizationImpl<XDataType,
-                                                          GammaDataType,
-                                                          BetaDataType,
-                                                          ComputeDataType,
-                                                          YDataType,
-                                                          YElementOp,
-                                                          Rank,
-                                                          NumReduceDim,
-                                                          1024, // BlockSize
-                                                          1,    // ClusterM
-                                                          1024, // ClusterK
-                                                          1,    // SliceM
-                                                          32,   // SliceK
-                                                          1,    // SrcVecDim (0=M, 1=K)
-                                                          2,    // SrcScalarPerVector
-                                                          1,    // GammaVecDim (0=M, 1=K)
-                                                          2,    // GammaScalarPerVector
-                                                          1,    // BetaVecDim (0=M, 1=K)
-                                                          2,    // BetaScalarPerVector
-                                                          2>;   // OutScalarPerVector
-
-int main(int argc, char* argv[])
+#pragma once
+
+int run_groupnorm_example(int argc, char* argv[])
 {
-    ck::index_t N = 2;
-    ck::index_t H = 32;
-    ck::index_t W = 32;
-    ck::index_t G = 32;
-    ck::index_t C = 30;
+    ck::index_t N = 32;
+    ck::index_t H = 16;
+    ck::index_t W = 16;
+    ck::index_t G = 64;
+    ck::index_t C = 128;

     if(argc == 1)
     {
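For reference, run_groupnorm_example.inc verifies against a host groupnorm over a 5-D (N, H, W, G, C) problem, reducing over H, W and C within each group (Rank = 5, NumReduceDim = 3); the library's reference lives in ck/library/reference_tensor_operation/cpu/reference_groupnorm.hpp. A naive host sketch of that computation, assuming per-(g, c) gamma/beta and the Swish-style epilogue of the fp16 examples (names and indexing here are illustrative, not the library's API):

#include <cmath>
#include <vector>

// Naive groupnorm over an (N, H, W, G, C) tensor: for each (n, g), normalize
// across the H, W and C dimensions (the three reduced dims), then apply
// gamma/beta and an elementwise epilogue (Swish with beta = 1 here).
void groupnorm_ref(const std::vector<float>& x, const std::vector<float>& gamma,
                   const std::vector<float>& beta, std::vector<float>& y,
                   int N, int H, int W, int G, int C, float eps = 1e-5f)
{
    auto idx = [&](int n, int h, int w, int g, int c) {
        return (((n * H + h) * W + w) * G + g) * C + c; // (N, H, W, G, C) layout
    };
    for(int n = 0; n < N; ++n)
        for(int g = 0; g < G; ++g)
        {
            double sum = 0, sq = 0;
            const double cnt = double(H) * W * C;
            for(int h = 0; h < H; ++h)
                for(int w = 0; w < W; ++w)
                    for(int c = 0; c < C; ++c)
                    {
                        double v = x[idx(n, h, w, g, c)];
                        sum += v;
                        sq += v * v;
                    }
            const double mean = sum / cnt;
            const double var  = sq / cnt - mean * mean;
            const double inv  = 1.0 / std::sqrt(var + eps);
            for(int h = 0; h < H; ++h)
                for(int w = 0; w < W; ++w)
                    for(int c = 0; c < C; ++c)
                    {
                        float t = float((x[idx(n, h, w, g, c)] - mean) * inv);
                        t       = gamma[g * C + c] * t + beta[g * C + c];
                        y[idx(n, h, w, g, c)] = t / (1.0f + std::exp(-t)); // Swish epilogue
                    }
        }
}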
include/ck/ck.hpp

@@ -36,7 +36,7 @@
 #elif defined(__gfx1030__) // for GPU code
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
 #elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) // for GPU code
-#define CK_BUFFER_RESOURCE_3RD_DWORD 0x10020000
+#define CK_BUFFER_RESOURCE_3RD_DWORD 0x31004000
 #endif

 // FMA instruction

@@ -163,9 +163,16 @@
 // workaround: compiler not emitting reciprocal instruction from __frcp_rn()
 #define CK_WORKAROUND_SWDEV_383542 1

+// workaround: compiler issue on gfx908
+#define CK_WORKAROUND_SWDEV_388832 1
+
 // flag to enable (1) or disable (0) the debugging output in some kernels
 #define DEBUG_LOG 0

+// denorm test fix, required to work around a denorm issue
+#ifndef CK_WORKAROUND_DENORM_FIX
+#define CK_WORKAROUND_DENORM_FIX 0
+#endif
+
 namespace ck {

 enum struct InMemoryDataOperationEnum
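Because the new CK_WORKAROUND_DENORM_FIX default is wrapped in #ifndef, a build can enable it without editing ck.hpp, either from the compiler command line or per translation unit. A minimal sketch (the file name is illustrative; the flag form is the usual preprocessor override):

// enable_denorm_fix.cpp, hypothetical translation unit.
// Equivalent to compiling with: hipcc -DCK_WORKAROUND_DENORM_FIX=1 ...
#define CK_WORKAROUND_DENORM_FIX 1 // must precede the first include of ck/ck.hpp
#include "ck/ck.hpp"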
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_wmma_cshuffle.hpp

@@ -840,17 +840,8 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
             << KPerBlock << ", "
             << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
             << K1 << ", "
-            << MPerXDL << ", "
-            << NPerXDL << ", "
-            << MXdlPerWave << ", "
-            << NXdlPerWave << ", "
             << ABlockTransferSrcScalarPerVector << ", "
-            << ABlockTransferDstScalarPerVector_K1 << ", "
             << BBlockTransferSrcScalarPerVector
-            << ", "
-            << BBlockTransferDstScalarPerVector_K1 << ", "
-            << CShuffleMXdlPerWavePerShuffle << ", "
-            << CShuffleNXdlPerWavePerShuffle << ", "
-            << CBlockTransferScalarPerVector_NWaveNPerXdl
             << ">";
         // clang-format on
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
include/ck/tensor_operation/gpu/element/quantization_operation.hpp

@@ -7,10 +7,30 @@ namespace ck {
 namespace tensor_operation {
 namespace element_wise {

+// Y = Sy * Qy
+// W = Sw * Qw
+// X = Sx * Qx
+// B = Sb * Qb = Sw * Sx * Qb
+// Where X, W, Y are float32 and Qx, Qw, Qy are int8.
+// Sx, Sw, Sy are the scales of x, w, y (float32), calculated from the quantization range.
+// Qb is int32; the scale of B is chosen as Sw * Sx for convenience.
+// Y = W @ X, where @ is convolution or matrix multiplication
+// Sy * Qy = Sw * Qw @ Sx * Qx
+// Qy = [(Sw*Sx)/Sy] * Qw @ Qx
+
 // For an activation function which is piecewise linear, such as relu, leaky relu ...etc
-// Activation(Sy * Qy) = Sy * Activation(Qy)
 template <typename Activation>
 struct Activation_Mul_Clamp
 {
+    // Convolution + Activation (piecewise linear function)
+    // If an activation is a piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy)
+    // Z = Activation(Y) = Activation(W @ X)
+    // Sz * Qz = Activation(Sy * Qy)
+    // Qz = Sy / Sz * Activation(Qy) = (Sw * Sx / Sz) * Activation(Qw @ Qx)
+    // requantScale_ = Sw * Sx / Sz
     Activation_Mul_Clamp(float requantScale, Activation activationOp)
         : requantScale_(requantScale), activationOp_(activationOp)
     {

@@ -45,8 +65,39 @@ struct Activation_Mul_Clamp
     Activation activationOp_;
 };

+// For an activation function which is not piecewise linear, such as TanH, Sigmoid ...etc:
+// Activation(Sy * Qy) != Sy * Activation(Qy)
+template <typename Activation>
+struct Mul_Activation_Mul_Clamp
+{
+    // Convolution + Activation (non piecewise linear function)
+    // Z = Activation(Y) = Activation(W @ X)
+    // Sz * Qz = Activation(Sy * Qy)
+    // Qz = S1 * Activation[Sacc * (Qw @ Qx)]
+    // Where S1 = 1 / Sz, Sacc = Sw * Sx
+    Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
+        : scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
+    {
+    }
+
+    __host__ __device__ constexpr void operator()(int8_t& y, const int32_t& x) const
+    {
+        float y_fp32 = ck::type_convert<float>(x);
+        y_fp32       = scaleAcc_ * y_fp32;
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
+        y      = ck::type_convert<int8_t>(y_fp32);
+    }
+
+    float scale_z_inv_;
+    float scaleAcc_;
+    Activation activationOp_;
+};
+
 // Conv Perchannel quantization + activation function which is piecewise linear, such as
 // relu, leaky relu ...etc
+// Activation(Sy * Qy) = Sy * Activation(Qy)
 template <typename Activation>
 struct Activation_Mul2_Clamp
 {

@@ -76,9 +127,20 @@ struct Activation_Mul2_Clamp
 };

 // For an activation function which is piecewise linear, such as relu, leaky relu ...etc
-// Activation(Sy * Qy) = Sy * Activation(Qy)
 template <typename Activation>
 struct Add_Activation_Mul_Clamp
 {
+    // Convolution + bias
+    // Let Bias = B = Sw * Sx * Qb
+    // Where Qb is int32
+    // Y = W @ X + B
+    // Sy * Qy = Sw * Qw @ Sx * Qx + Sw * Sx * Qb
+    // Qy = [(Sw*Sx)/Sy] * (Qw @ Qx + Qb)
+    // For activation, Z = Activation(Y)
+    // Sz * Qz = Activation(Sy * Qy)
+    // Qz = Sy / Sz * Activation(Qy) = [(Sw*Sx)/Sz] * Activation(Qw @ Qx + Qb)
     Add_Activation_Mul_Clamp(float requantScale, Activation activationOp)
         : requantScale_(requantScale), activationOp_(activationOp)
     {

@@ -139,11 +201,18 @@ struct Add_Activation_Mul2_Clamp
 };

 // For an activation function which is not piecewise linear, such as TanH, Sigmoid ...etc
+// If an activation is not a piecewise linear function,
+// then Activation(Sy * Qy) != Sy * Activation(Qy)
 template <typename Activation>
 struct Add_Mul_Activation_Mul_Clamp
 {
-    Add_Mul_Activation_Mul_Clamp(float requantScale1, float requantScale2, Activation activationOp)
-        : requantScale1_(requantScale1), requantScale2_(requantScale2), activationOp_(activationOp)
+    // Convolution + Activation (non piecewise linear function)
+    // Z = Activation(Y) = Activation(W @ X + B)
+    // Sz * Qz = Activation(Sy * Qy)
+    // Qz = S1 * Activation[Sacc * (Qw @ Qx + Qb)]
+    // Where S1 = 1 / Sz, Sacc = Sw * Sx
+    Add_Mul_Activation_Mul_Clamp(float scale_z_inv, float scaleAcc, Activation activationOp)
+        : scale_z_inv_(scale_z_inv), scaleAcc_(scaleAcc), activationOp_(activationOp)
     {
     }

@@ -151,14 +220,64 @@ struct Add_Mul_Activation_Mul_Clamp
     operator()(int8_t& y, const int32_t& x, const int32_t& bias) const
     {
         float y_fp32 = ck::type_convert<float>(x + bias);
-        y_fp32       = requantScale1_ * y_fp32;
+        y_fp32       = scaleAcc_ * y_fp32;
         activationOp_(y_fp32, y_fp32);
-        y_fp32 = math::clamp(requantScale2_ * y_fp32, -128.f, 127.f);
+        y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
         y      = ck::type_convert<int8_t>(y_fp32);
     }

+    __host__ __device__ constexpr void
+    operator()(int32_t& y, const int32_t& x, const int32_t& bias) const
+    {
+        // CAUTION - we might type_convert to int8 in threadwise copy,
+        // e.g. GridwiseGemmDlMultipleD_km_kn_mn
+        float y_fp32 = ck::type_convert<float>(x + bias);
+        y_fp32       = scaleAcc_ * y_fp32;
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
+        y      = ck::type_convert<int32_t>(y_fp32);
+    }
+
-    float requantScale1_;
-    float requantScale2_;
+    float scale_z_inv_;
+    float scaleAcc_;
     Activation activationOp_;
 };
+
+// Conv Perchannel quantization + activation function which is not piecewise linear,
+// such as TanH, Sigmoid ...etc
+// If an activation is not a piecewise linear function,
+// then Activation(Sy * Qy) != Sy * Activation(Qy)
+template <typename Activation>
+struct Add_Mul2_Activation_Mul_Clamp
+{
+    Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp)
+        : scale_z_inv_(scale_z_inv), activationOp_(activationOp)
+    {
+    }
+
+    __host__ __device__ constexpr void
+    operator()(int8_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
+    {
+        float y_fp32 = ck::type_convert<float>(x + bias);
+        y_fp32       = scaleAcc * y_fp32;
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
+        y      = ck::type_convert<int8_t>(y_fp32);
+    }
+
+    __host__ __device__ constexpr void
+    operator()(int32_t& y, const int32_t& x, const int32_t& bias, const float& scaleAcc) const
+    {
+        // CAUTION - we might type_convert to int8 in threadwise copy,
+        // e.g. GridwiseGemmDlMultipleD_km_kn_mn
+        float y_fp32 = ck::type_convert<float>(x + bias);
+        y_fp32       = scaleAcc * y_fp32;
+        activationOp_(y_fp32, y_fp32);
+        y_fp32 = math::clamp(scale_z_inv_ * y_fp32, -128.f, 127.f);
+        y      = ck::type_convert<int32_t>(y_fp32);
+    }
+
+    float scale_z_inv_;
+    Activation activationOp_;
+};
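To make the piecewise-linear path concrete: with illustrative scales Sx, Sw, Sz (values below are not from the source), requantScale = Sw*Sx/Sz, and Activation_Mul_Clamp reduces to activate, scale, clamp, narrow. A standalone host sketch using std replacements for ck::type_convert and math::clamp:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical quantization scales for x, w and z (all float32 in the scheme above).
constexpr float Sx = 0.05f, Sw = 0.02f, Sz = 0.1f;

// Mirrors Activation_Mul_Clamp<Relu> for one int32 conv accumulator:
// Qz = (Sw*Sx/Sz) * Relu(Qw @ Qx), clamped to the int8 range.
int8_t requant_relu(int32_t acc)
{
    const float requantScale = Sw * Sx / Sz; // 0.01 with the scales above
    float y = static_cast<float>(acc);
    y       = y > 0.f ? y : 0.f; // Relu applied to the integer accumulator
    y       = std::min(std::max(requantScale * y, -128.f), 127.f);
    return static_cast<int8_t>(y);
}

int main()
{
    // acc = 9500 -> 0.01 * 9500 = 95, fits in int8; acc = 20000 saturates to 127.
    std::printf("%d %d\n", requant_relu(9500), requant_relu(20000));
}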
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -316,8 +316,36 @@ struct Sigmoid
         y = 1 / (ck::type_convert<T>(1) + exp(-x));
     };
-    int32_t divider_ = 1;
 };

+struct TanH
+{
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, ck::half_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = ck::math::tanh(x);
+    };
+};
+
+struct Swish
+{
+    Swish(float beta = 1.0f) : beta_(beta) {}
+
+    template <typename T>
+    __host__ __device__ void operator()(T& y, const T& x) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, ck::half_t>::value,
+                      "Data type is not supported by this operation!");
+
+        y = x / (ck::type_convert<T>(1) + ck::math::exp(-beta_ * x));
+    };
+
+    float beta_ = 1.0f;
+};
+
 } // namespace element_wise
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp

@@ -431,6 +431,9 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         constexpr auto b_block_desc_k0perblock_nperblock_k1 =
             GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();

+        constexpr auto cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
+            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();
+
         constexpr auto max_lds_align = K1;

         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(

@@ -439,8 +442,13 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
             b_block_desc_k0perblock_nperblock_k1.GetElementSpaceSize(), max_lds_align);

-        return (a_block_space_size_aligned * sizeof(ADataType) +
-                b_block_space_size_aligned * sizeof(BDataType));
+        constexpr auto c_block_space_size_aligned = math::integer_least_multiple(
+            cshuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat.GetElementSpaceSize(),
+            max_lds_align);
+
+        return math::max((a_block_space_size_aligned * sizeof(ADataType) +
+                          b_block_space_size_aligned * sizeof(BDataType)),
+                         c_block_space_size_aligned * sizeof(CShuffleDataType));
     }

     // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}

@@ -497,6 +505,15 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         }

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
+        constexpr long_index_t TwoGB = (long_index_t{1} << 31);
+
+        if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(ADataType) <= TwoGB &&
+             b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(BDataType) <= TwoGB &&
+             e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
+        {
+            return false;
+        }
+
         return true;
     }
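The new descriptor check rejects problem sizes whose A, B or E buffers exceed 2 GiB, presumably because the buffer addressing used by these kernels works with 32-bit offsets. A standalone sketch of the same arithmetic (plain C++; fits_in_2gb is a hypothetical helper, not a CK function):

#include <cstdint>
#include <cstdio>

using long_index_t = std::int64_t;

// Mirrors the guard above: the element count is widened to 64 bits before being
// multiplied by sizeof(T) and compared against 2 GiB, so the byte-size product
// cannot overflow a 32-bit index.
template <typename T>
bool fits_in_2gb(long_index_t element_space_size)
{
    constexpr long_index_t TwoGB = long_index_t{1} << 31;
    return element_space_size * static_cast<long_index_t>(sizeof(T)) <= TwoGB;
}

int main()
{
    // 400M floats = 1.6 GB passes; 600M floats = 2.4 GB would overflow int32
    // if computed in 32 bits, and is correctly rejected here.
    std::printf("%d %d\n", fits_in_2gb<float>(400000000), fits_in_2gb<float>(600000000));
}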
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -92,6 +92,17 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
     using GridwiseGemmPipe = remove_cvref_t<decltype(
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

+    // denorm test fix, required to work around an fp16 mfma issue:
+    // we convert fp16 -> fp32 -> bf16 and execute the bf16 mfma instruction.
+    // When mfma is fixed, remove this section and update
+    // ABDataTypeAdjusted -> ABDataType throughout this file.
+#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
+    using ABDataTypeAdjusted =
+        conditional_t<is_same_v<ABDataType, ck::half_t>, ck::bhalf_t, ABDataType>;
+#else
+    using ABDataTypeAdjusted = ABDataType;
+#endif
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
     {
         // A matrix in LDS memory, dst of blockwise copy

@@ -397,7 +408,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
             ABlockTransferThreadClusterLengths_AK0_M_AK1,
             ABlockTransferThreadClusterArrangeOrder,
             ABDataType,
-            ABDataType,
+            ABDataTypeAdjusted,
             decltype(a_grid_desc_ak0_m_ak1),
             decltype(a_block_desc_ak0_m_ak1),
             ABlockTransferSrcAccessOrder,

@@ -428,7 +439,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
             BBlockTransferThreadClusterLengths_BK0_N_BK1,
             BBlockTransferThreadClusterArrangeOrder,
             ABDataType,
-            ABDataType,
+            ABDataTypeAdjusted,
             decltype(b_grid_desc_bk0_n_bk1),
             decltype(b_block_desc_bk0_n_bk1),
             BBlockTransferSrcAccessOrder,

@@ -458,11 +469,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
         // sanity check
         constexpr index_t KPack =
             math::max(math::lcm(AK1, BK1),
-                      MfmaSelector<ABDataType, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
+                      MfmaSelector<ABDataTypeAdjusted, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

         auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
             BlockSize,
-            ABDataType,
+            ABDataTypeAdjusted,
             AccDataType,
             decltype(a_block_desc_ak0_m_ak1),
             decltype(b_block_desc_bk0_n_bk1),

@@ -480,10 +491,11 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
+            static_cast<ABDataTypeAdjusted*>(p_shared),
+            a_block_desc_ak0_m_ak1.GetElementSpaceSize());

         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
+            static_cast<ABDataTypeAdjusted*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_bk0_n_bk1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
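The ABDataTypeAdjusted alias above is a compile-time type swap: when CK_WORKAROUND_DENORM_FIX is active on gfx90a, fp16 is replaced by bf16 for the LDS tiles, MFMA selection and blockwise GEMM, while every other type passes through unchanged. A standalone mirror using the standard-library equivalents of ck's conditional_t/is_same_v (half_t and bhalf_t here are hypothetical stand-in tag types, not the ck types):

#include <type_traits>

struct half_t;  // stand-in for ck::half_t
struct bhalf_t; // stand-in for ck::bhalf_t

// Selects bf16 storage when the incoming A/B type is fp16, else passes through.
template <typename ABDataType>
using ABDataTypeAdjusted =
    std::conditional_t<std::is_same_v<ABDataType, half_t>, bhalf_t, ABDataType>;

static_assert(std::is_same_v<ABDataTypeAdjusted<half_t>, bhalf_t>);
static_assert(std::is_same_v<ABDataTypeAdjusted<float>, float>);

int main() {}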
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

@@ -264,6 +264,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
         }

         // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
+        constexpr long_index_t TwoGB = (long_index_t{1} << 31);
+
+        if(!(a_grid_desc_k0_m_k1.GetElementSpaceSize() * sizeof(FloatA) <= TwoGB &&
+             b_grid_desc_k0_n_k1.GetElementSpaceSize() * sizeof(FloatB) <= TwoGB))
+        {
+            return false;
+        }
+
         return true;
     }
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -166,15 +166,12 @@ __global__ void
                 const CBlockClusterAdaptor c_block_cluster_adaptor)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
-    constexpr index_t shared_block_size =
-        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
-
-    __shared__ FloatAB p_shared_block[shared_block_size];
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
                                                   p_b_grid,
                                                   p_c_grid,
-                                                  p_shared_block,
+                                                  p_shared,
                                                   a_b_k0_m_k1_grid_desc,
                                                   b_b_k0_n_k1_grid_desc,
                                                   c_grid_desc_mblock_mperblock_nblock_nperblock,

@@ -183,16 +180,16 @@ __global__ void
                                                   c_element_op,
                                                   c_block_cluster_adaptor);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_c_grid;
     ignore = a_b_k0_m_k1_grid_desc;
     ignore = b_b_k0_n_k1_grid_desc;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
     ignore = c_block_cluster_adaptor;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

@@ -264,6 +261,16 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
     using GridwiseGemmPipe = remove_cvref_t<decltype(
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

+    // denorm test fix, required to work around an fp16 mfma issue:
+    // we convert fp16 -> fp32 -> bf16 and execute the bf16 mfma instruction.
+    // When mfma is fixed, remove this section and update
+    // FloatABAdjusted -> FloatAB throughout this file.
+#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
+    using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
+#else
+    using FloatABAdjusted = FloatAB;
+#endif
+
     // M0/M1/M1Padding
     static constexpr auto M1PerBlock = Number<ABlockLdsM1PerBlock>{};
     static constexpr auto M0PerBlock = Number<ABlockLdsM0PerBlock>{};

@@ -605,7 +612,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
                                const FloatAB* __restrict__ p_b_grid,
                                FloatC* __restrict__ p_c_grid,
-                               FloatAB* __restrict__ p_shared_block,
+                               void* __restrict__ p_shared,
                                const AGridDesc_B_K0_M_K1& a_b_k0_m_k1_grid_desc,
                                const BGridDesc_B_K0_N_K1& b_b_k0_n_k1_grid_desc,
                                const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&

@@ -666,7 +673,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
             ABlockTransferThreadClusterLengths_K0_M_K1,
             ABlockTransferThreadClusterArrangeOrder,
             FloatAB,
-            FloatAB,
+            FloatABAdjusted,
             decltype(a_b_k0_m_k1_grid_desc),
             decltype(a_b_k0_m_k1_block_desc),
             ABlockTransferSrcAccessOrder,

@@ -696,7 +703,7 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
             BBlockTransferThreadClusterLengths_K0_N_K1,
             BBlockTransferThreadClusterArrangeOrder,
             FloatAB,
-            FloatAB,
+            FloatABAdjusted,
             decltype(b_b_k0_n_k1_grid_desc),
             decltype(b_b_k0_n_k1_block_desc),
             BBlockTransferSrcAccessOrder,

@@ -725,11 +732,11 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
         // sanity check
         constexpr index_t KPack =
-            math::max(K1, MfmaSelector<FloatAB, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);
+            math::max(K1, MfmaSelector<FloatABAdjusted, MPerXDL, NPerXDL>::selected_mfma.k_per_blk);

         auto blockwise_gemm =
             BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                                FloatAB,
+                                                                FloatABAdjusted,
                                                                 FloatAcc,
                                                                 decltype(a_k0_m_k1_block_desc),
                                                                 decltype(b_k0_n_k1_block_desc),

@@ -745,16 +752,15 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
         constexpr auto a_block_space_size =
             math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align);

-        FloatAB* p_a_block = p_shared_block;
-        FloatAB* p_b_block = p_shared_block + a_block_space_size;
-
         constexpr auto a_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(0, K0PerBlock, 0, 0);

         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
+            static_cast<FloatABAdjusted*>(p_shared), a_k0_m_k1_block_desc.GetElementSpaceSize());
         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());
+            static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size,
+            b_k0_n_k1_block_desc.GetElementSpaceSize());

         // gridwise GEMM pipeline
         const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

@@ -798,8 +804,6 @@ struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight
         constexpr auto c_block_desc_mblock_mperblock_nblock_nperblock =
             GetCBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

-        void* p_shared = static_cast<void*>(p_shared_block);
-
         auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<FloatC*>(p_shared),
             c_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
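The LDS change in this file is a design choice worth noting: the kernel now declares shared memory as a raw char array sized in bytes and passes it to Run() as void*, which retypes it per phase (FloatABAdjusted for the A/B tiles, FloatC for the C shuffle), so one allocation serves differently typed stages. A host-side analogue of the partitioning (types and sizes below are illustrative stand-ins, not the kernel's):

#include <cstddef>
#include <cstdint>

// On device the buffer is
//     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// and Run() receives it as void* __restrict__ p_shared.
using FloatAB = std::uint16_t; // stand-in for the 16-bit fp16/bf16 A/B element type
using FloatC  = float;         // stand-in for the C shuffle element type

void partition(void* p_shared, std::size_t a_block_space_size)
{
    // A tile lives at offset 0 and the B tile directly after it, both counted
    // in FloatAB elements...
    FloatAB* a_block = static_cast<FloatAB*>(p_shared);
    FloatAB* b_block = static_cast<FloatAB*>(p_shared) + a_block_space_size;
    // ...while the later C shuffle phase reinterprets the same bytes as FloatC,
    // which is why GetSharedMemoryNumberOfByte() sizes the buffer in bytes.
    FloatC* c_block = static_cast<FloatC*>(p_shared);
    (void)a_block;
    (void)b_block;
    (void)c_block;
}

int main()
{
    alignas(8) char bytes[1024];
    partition(bytes, 128);
}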
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

@@ -58,16 +58,16 @@ __global__ void
                                                   c_element_op,
                                                   block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_c_grid;
     ignore = a_grid_desc_k0_m_k1;
     ignore = b_grid_desc_k0_n_k1;
     ignore = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }

@@ -131,6 +131,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
     using GridwiseGemmPipe = remove_cvref_t<decltype(
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;

+    // denorm test fix, required to work around an fp16 mfma issue:
+    // we convert fp16 -> fp32 -> bf16 and execute the bf16 mfma instruction.
+    // When mfma is fixed, remove this section and update
+    // FloatABAdjusted -> FloatAB throughout this file.
+#if CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
+    using FloatABAdjusted = conditional_t<is_same_v<FloatAB, ck::half_t>, ck::bhalf_t, FloatAB>;
+#else
+    using FloatABAdjusted = FloatAB;
+#endif
+
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
         constexpr auto max_lds_align = K1;

@@ -281,7 +291,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         using BlockwiseGemm =
             BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
-                                                                FloatAB,
+                                                                FloatABAdjusted,
                                                                 FloatAcc,
                                                                 decltype(a_block_desc_k0_m_k1),
                                                                 decltype(b_block_desc_k0_n_k1),

@@ -367,7 +377,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             ABlockTransferThreadClusterLengths_K0_M_K1,
             ABlockTransferThreadClusterArrangeOrder,
             FloatAB,
-            FloatAB,
+            FloatABAdjusted,
             decltype(a_grid_desc_k0_m_k1),
             decltype(a_block_desc_k0_m_k1),
             ABlockTransferSrcAccessOrder,

@@ -398,7 +408,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             BBlockTransferThreadClusterLengths_K0_N_K1,
             BBlockTransferThreadClusterArrangeOrder,
             FloatAB,
-            FloatAB,
+            FloatABAdjusted,
             decltype(b_grid_desc_k0_n_k1),
             decltype(b_block_desc_k0_n_k1),
             BBlockTransferSrcAccessOrder,

@@ -428,7 +438,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
         // sanity check
         auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
             BlockSize,
-            FloatAB,
+            FloatABAdjusted,
             FloatAcc,
             decltype(a_block_desc_k0_m_k1),
             decltype(b_block_desc_k0_n_k1),

@@ -446,10 +456,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
             math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);

         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
+            static_cast<FloatABAdjusted*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());

         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
+            static_cast<FloatABAdjusted*>(p_shared) + a_block_space_size_aligned,
             b_block_desc_k0_n_k1.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);