Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
9ba504b6
"driver/vscode:/vscode.git/clone" did not exist on "9e0d61465612000caf2050cf834f4693792c5770"
Commit
9ba504b6
authored
Feb 07, 2025
by
ThomasNing
Browse files
merge with the develop support the fp8 with computev4
parents
e3402c93
f49de496
Changes
198
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
267 additions
and
84 deletions
+267
-84
example/ck_tile/03_gemm/universal_gemm.cpp
example/ck_tile/03_gemm/universal_gemm.cpp
+91
-8
include/ck/README.md
include/ck/README.md
+20
-16
include/ck/ck.hpp
include/ck/ck.hpp
+12
-2
include/ck/config.h.in
include/ck/config.h.in
+4
-0
include/ck/host_utility/device_prop.hpp
include/ck/host_utility/device_prop.hpp
+4
-3
include/ck/library/utility/check_err.hpp
include/ck/library/utility/check_err.hpp
+75
-24
include/ck/library/utility/host_tensor_generator.hpp
include/ck/library/utility/host_tensor_generator.hpp
+43
-0
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...gen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
...pl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
...ion/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
...gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
+2
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
...ation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
..._batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
...u/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
...device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
+2
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
...ce/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
+2
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
...sor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
...ice/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
...u/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
+2
-3
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
...r_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
+2
-3
No files found.
example/ck_tile/03_gemm/universal_gemm.cpp
View file @
9ba504b6
...
@@ -12,7 +12,13 @@
...
@@ -12,7 +12,13 @@
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "gemm_basic.hpp"
#include "gemm_basic.hpp"
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
float
gemm_calc
(
const
ck_tile
::
GemmHostArgs
&
args
,
const
ck_tile
::
stream_config
&
s
)
{
{
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
#if(CK_TILE_PIPELINE_DEFAULT == CK_TILE_PIPELINE_MEMORY)
...
@@ -59,7 +65,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
...
@@ -59,7 +65,7 @@ float gemm_calc(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config&
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
M_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
N_Warp_Tile
=
32
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
8
;
constexpr
ck_tile
::
index_t
K_Warp_Tile
=
16
;
constexpr
bool
DoubleSmemBuffer
=
true
;
constexpr
bool
DoubleSmemBuffer
=
true
;
#endif
#endif
...
@@ -279,24 +285,101 @@ int run_gemm_example(int argc, char* argv[])
...
@@ -279,24 +285,101 @@ int run_gemm_example(int argc, char* argv[])
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
;
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
data_type
=
arg_parser
.
get_str
(
"prec"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
std
::
string
a_layout
=
arg_parser
.
get_str
(
"a_layout"
);
std
::
string
b_layout
=
arg_parser
.
get_str
(
"b_layout"
);
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
if
(
a_layout
==
"R"
&&
b_layout
==
"R"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Row
{},
Row
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
}
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
else
if
(
a_layout
==
"R"
&&
b_layout
==
"C"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Row
{},
Col
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"C"
)
else
if
(
a_layout
==
"C"
&&
b_layout
==
"C"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Col
{},
Col
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
}
else
if
(
a_layout
==
"C"
&&
b_layout
==
"R"
)
else
if
(
a_layout
==
"C"
&&
b_layout
==
"R"
)
{
{
return
run_gemm_example_with_layouts
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
if
(
data_type
==
"fp16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
half_t
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"bf16"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf16_t
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"fp8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
fp8_t
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
if
(
data_type
==
"bf8"
)
{
return
run_gemm_example_with_layouts
<
ck_tile
::
bf8_t
>
(
argc
,
argv
,
Col
{},
Row
{},
Row
{});
}
else
{
throw
std
::
runtime_error
(
"Unsupported data_type!"
);
}
}
}
else
else
{
{
...
...
include/ck/README.md
View file @
9ba504b6
[
Back to the main page
](
../../README.md
)
[
Back to the main page
](
../../README.md
)
# Composable Kernel supported operations
# Composable Kernel supported operations
## Supported device operations
## Supported device operations
*
[
Average pooling
](
)
<!-- * [Average pooling](../../docs/markdown/tensor_operation/average_pooling.md) -->
*
[
Batched contraction
](
)
<!-- * [Batched contraction](../../docs/markdown/tensor_operation/batched_contraction.md) -->
*
[
Batched gemm
](
)
<!-- * [Batched gemm](../../docs/markdown/tensor_operation/batched_gemm.md) -->
*
[
Batchnorm
](
)
<!-- * [Batchnorm](../../docs/markdown/tensor_operation/batchnorm.md) -->
*
[
CGEMM
](
)
<!-- * [CGEMM](../../docs/markdown/tensor_operation/cgemm.md) -->
*
[
Contraction
](
)
<!-- * [Contraction](../../docs/markdown/tensor_operation/contraction.md) -->
*
[
Convolution
](
)
<!-- * [Convolution](../../docs/markdown/tensor_operation/convolution.md) -->
*
[
Image to Column and Column to Image
](
)
<!-- * [Elementwise](../../docs/markdown/tensor_operation/elementwise.md) -->
*
[
Elementwise
](
)
*
[
GEMM
](
../../client_example/01_gemm/README.md
)
*
[
GEMM
](
)
*
[
Grouped Convolution Forward
](
../../client_example/07_grouped_convnd_fwd/README.md
)
*
[
Max pooling
](
)
*
[
Grouped Convolution Backward Data
](
../../client_example/10_grouped_convnd_bwd_data/README.md
)
*
[
Reduce
](
)
*
[
Grouped Convolution Backward Weight
](
../../client_example/11_grouped_conv_bwd_weight/README.md
)
*
[
Normalization
](
)
<!-- * [Grouped GEMM](../../docs/markdown/tensor_operation/grouped_gemm.md) -->
*
[
Permute
](
)
<!-- * [Image to Column and Column to Image](../../docs/markdown/tensor_operation/img2col.md) -->
*
[
Put
](
)
<!-- * [Max pooling](../../docs/markdown/tensor_operation/max_pooling.md) -->
*
[
Softmax
](
)
<!-- * [Reduce](../../docs/markdown/tensor_operation/reduce.md) -->
<!-- * [Normalization](../../docs/markdown/tensor_operation/normalization.md) -->
<!-- * [Permute](../../docs/markdown/tensor_operation/permute.md) -->
<!-- * [Put](../../docs/markdown/tensor_operation/put.md) -->
<!-- * [Softmax](../../docs/markdown/tensor_operation/softmax.md) -->
include/ck/ck.hpp
View file @
9ba504b6
...
@@ -55,10 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -55,10 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// define general macros for various architectures
// define general macros for various architectures
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
defined(__gfx942__)
defined(__gfx942__)
|| defined(__gfx950__)
#define __gfx9__
#define __gfx9__
#endif
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
|| defined(__gfx950__)
#define __gfx94__
#define __gfx94__
#endif
#endif
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
...
@@ -163,6 +163,16 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
...
@@ -163,6 +163,16 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// set rounding to nearest even as default for f8 conversions
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0
#define CK_USE_SR_F8_CONVERSION 0
// set rounding to nearest even as default for f6 conversions
#define CK_USE_SR_F6_CONVERSION 0
// set rounding to nearest even as default for f4 conversions
#define CK_USE_SR_F4_CONVERSION 0
// shuffle pk_i4 values during conversion to optimize number of binary
// operations
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...
...
include/ck/config.h.in
View file @
9ba504b6
...
@@ -131,6 +131,10 @@
...
@@ -131,6 +131,10 @@
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#endif
#endif
#ifndef CK_USE_NATIVE_MX_SUPPORT
#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@
#endif
// clang-format on
// clang-format on
#endif // CK_CONFIG_H_IN
#endif // CK_CONFIG_H_IN
include/ck/host_utility/device_prop.hpp
View file @
9ba504b6
...
@@ -55,20 +55,21 @@ inline bool is_xdl_supported()
...
@@ -55,20 +55,21 @@ inline bool is_xdl_supported()
{
{
return
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
return
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
;
ck
::
get_device_name
()
==
"gfx942"
||
ck
::
get_device_name
()
==
"gfx950"
;
}
}
inline
bool
is_lds_direct_load_supported
()
inline
bool
is_lds_direct_load_supported
()
{
{
// Check if direct loads from global memory to LDS are supported.
// Check if direct loads from global memory to LDS are supported.
return
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
return
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
;
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
||
ck
::
get_device_name
()
==
"gfx950"
;
}
}
inline
bool
is_bf16_atomic_supported
()
inline
bool
is_bf16_atomic_supported
()
{
{
return
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
return
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
;
ck
::
get_device_name
()
==
"gfx942"
||
ck
::
get_device_name
()
==
"gfx950"
;
}
}
inline
bool
is_gfx101_supported
()
inline
bool
is_gfx101_supported
()
...
...
include/ck/library/utility/check_err.hpp
View file @
9ba504b6
...
@@ -26,6 +26,7 @@ namespace utils {
...
@@ -26,6 +26,7 @@ namespace utils {
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_relative_threshold
(
const
int
number_of_accumulations
=
1
)
double
get_relative_threshold
(
const
int
number_of_accumulations
=
1
)
{
{
using
F4
=
ck
::
f4_t
;
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
...
@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
...
@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
using
I8
=
int8_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
using
I32
=
int32_t
;
static_assert
(
is_same_v
<
ComputeDataType
,
F
8
>
||
is_same_v
<
ComputeDataType
,
F
16
>
||
static_assert
(
is_same_v
<
ComputeDataType
,
F
4
>
||
is_same_v
<
ComputeDataType
,
F
8
>
||
is_same_v
<
ComputeDataType
,
B
F16
>
||
is_same_v
<
ComputeDataType
,
F32
>
||
is_same_v
<
ComputeDataType
,
F16
>
||
is_same_v
<
ComputeDataType
,
BF16
>
||
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I
32
>
||
is_same_v
<
ComputeDataType
,
F32
>
||
is_same_v
<
ComputeDataType
,
I
8
>
||
is_same_v
<
ComputeDataType
,
int
>
,
is_same_v
<
ComputeDataType
,
I32
>
||
is_same_v
<
ComputeDataType
,
int
>
,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"
);
"Warning: Unhandled ComputeDataType for setting up the relative threshold!"
);
double
compute_error
=
0
;
double
compute_error
=
0
;
if
constexpr
(
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I32
>
||
if
constexpr
(
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I32
>
||
...
@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
...
@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
compute_error
=
std
::
pow
(
2
,
-
NumericUtils
<
ComputeDataType
>::
mant
)
*
0.5
;
compute_error
=
std
::
pow
(
2
,
-
NumericUtils
<
ComputeDataType
>::
mant
)
*
0.5
;
}
}
static_assert
(
is_same_v
<
OutDataType
,
F
8
>
||
is_same_v
<
OutDataType
,
F
16
>
||
static_assert
(
is_same_v
<
OutDataType
,
F
4
>
||
is_same_v
<
OutDataType
,
F
8
>
||
is_same_v
<
OutDataType
,
B
F16
>
||
is_same_v
<
OutDataType
,
F32
>
||
is_same_v
<
OutDataType
,
F16
>
||
is_same_v
<
OutDataType
,
BF16
>
||
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I
32
>
||
is_same_v
<
OutDataType
,
F32
>
||
is_same_v
<
OutDataType
,
I
8
>
||
is_same_v
<
OutDataType
,
int
>
,
is_same_v
<
OutDataType
,
I32
>
||
is_same_v
<
OutDataType
,
int
>
,
"Warning: Unhandled OutDataType for setting up the relative threshold!"
);
"Warning: Unhandled OutDataType for setting up the relative threshold!"
);
double
output_error
=
0
;
double
output_error
=
0
;
if
constexpr
(
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
if
constexpr
(
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
...
@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
...
@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
}
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
is_same_v
<
AccDataType
,
F
8
>
||
is_same_v
<
AccDataType
,
F
16
>
||
static_assert
(
is_same_v
<
AccDataType
,
F
4
>
||
is_same_v
<
AccDataType
,
F
8
>
||
is_same_v
<
AccDataType
,
B
F16
>
||
is_same_v
<
AccDataType
,
F32
>
||
is_same_v
<
AccDataType
,
F16
>
||
is_same_v
<
AccDataType
,
BF16
>
||
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I
32
>
||
is_same_v
<
AccDataType
,
F32
>
||
is_same_v
<
AccDataType
,
I
8
>
||
is_same_v
<
AccDataType
,
int
>
,
is_same_v
<
AccDataType
,
I32
>
||
is_same_v
<
AccDataType
,
int
>
,
"Warning: Unhandled AccDataType for setting up the relative threshold!"
);
"Warning: Unhandled AccDataType for setting up the relative threshold!"
);
double
acc_error
=
0
;
double
acc_error
=
0
;
if
constexpr
(
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
if
constexpr
(
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
...
@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
...
@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
template
<
typename
ComputeDataType
,
typename
OutDataType
,
typename
AccDataType
=
ComputeDataType
>
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number_of_accumulations
=
1
)
double
get_absolute_threshold
(
const
double
max_possible_num
,
const
int
number_of_accumulations
=
1
)
{
{
using
F4
=
ck
::
f4_t
;
using
F8
=
ck
::
f8_t
;
using
F8
=
ck
::
f8_t
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
...
@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
...
@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
using
I8
=
int8_t
;
using
I8
=
int8_t
;
using
I32
=
int32_t
;
using
I32
=
int32_t
;
static_assert
(
is_same_v
<
ComputeDataType
,
F
8
>
||
is_same_v
<
ComputeDataType
,
F
16
>
||
static_assert
(
is_same_v
<
ComputeDataType
,
F
4
>
||
is_same_v
<
ComputeDataType
,
F
8
>
||
is_same_v
<
ComputeDataType
,
B
F16
>
||
is_same_v
<
ComputeDataType
,
F32
>
||
is_same_v
<
ComputeDataType
,
F16
>
||
is_same_v
<
ComputeDataType
,
BF16
>
||
is_same_v
<
ComputeDataType
,
I8
>
||
is_same_v
<
ComputeDataType
,
I
32
>
||
is_same_v
<
ComputeDataType
,
F32
>
||
is_same_v
<
ComputeDataType
,
I
8
>
||
is_same_v
<
ComputeDataType
,
int
>
,
is_same_v
<
ComputeDataType
,
I32
>
||
is_same_v
<
ComputeDataType
,
int
>
,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"
);
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!"
);
auto
expo
=
std
::
log2
(
std
::
abs
(
max_possible_num
));
auto
expo
=
std
::
log2
(
std
::
abs
(
max_possible_num
));
double
compute_error
=
0
;
double
compute_error
=
0
;
...
@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
...
@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
compute_error
=
std
::
pow
(
2
,
expo
-
NumericUtils
<
ComputeDataType
>::
mant
)
*
0.5
;
compute_error
=
std
::
pow
(
2
,
expo
-
NumericUtils
<
ComputeDataType
>::
mant
)
*
0.5
;
}
}
static_assert
(
is_same_v
<
OutDataType
,
F
8
>
||
is_same_v
<
OutDataType
,
F
16
>
||
static_assert
(
is_same_v
<
OutDataType
,
F
4
>
||
is_same_v
<
OutDataType
,
F
8
>
||
is_same_v
<
OutDataType
,
B
F16
>
||
is_same_v
<
OutDataType
,
F32
>
||
is_same_v
<
OutDataType
,
F16
>
||
is_same_v
<
OutDataType
,
BF16
>
||
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I
32
>
||
is_same_v
<
OutDataType
,
F32
>
||
is_same_v
<
OutDataType
,
I
8
>
||
is_same_v
<
OutDataType
,
int
>
,
is_same_v
<
OutDataType
,
I32
>
||
is_same_v
<
OutDataType
,
int
>
,
"Warning: Unhandled OutDataType for setting up the absolute threshold!"
);
"Warning: Unhandled OutDataType for setting up the absolute threshold!"
);
double
output_error
=
0
;
double
output_error
=
0
;
if
constexpr
(
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
if
constexpr
(
is_same_v
<
OutDataType
,
I8
>
||
is_same_v
<
OutDataType
,
I32
>
||
...
@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
...
@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
}
}
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
double
midway_error
=
std
::
max
(
compute_error
,
output_error
);
static_assert
(
is_same_v
<
AccDataType
,
F
8
>
||
is_same_v
<
AccDataType
,
F
16
>
||
static_assert
(
is_same_v
<
AccDataType
,
F
4
>
||
is_same_v
<
AccDataType
,
F
8
>
||
is_same_v
<
AccDataType
,
B
F16
>
||
is_same_v
<
AccDataType
,
F32
>
||
is_same_v
<
AccDataType
,
F16
>
||
is_same_v
<
AccDataType
,
BF16
>
||
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I
32
>
||
is_same_v
<
AccDataType
,
F32
>
||
is_same_v
<
AccDataType
,
I
8
>
||
is_same_v
<
AccDataType
,
int
>
,
is_same_v
<
AccDataType
,
I32
>
||
is_same_v
<
AccDataType
,
int
>
,
"Warning: Unhandled AccDataType for setting up the absolute threshold!"
);
"Warning: Unhandled AccDataType for setting up the absolute threshold!"
);
double
acc_error
=
0
;
double
acc_error
=
0
;
if
constexpr
(
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
if
constexpr
(
is_same_v
<
AccDataType
,
I8
>
||
is_same_v
<
AccDataType
,
I32
>
||
...
@@ -450,5 +452,54 @@ check_err(const Range& out,
...
@@ -450,5 +452,54 @@ check_err(const Range& out,
return
res
;
return
res
;
}
}
template
<
typename
Range
,
typename
RefRange
>
std
::
enable_if_t
<
(
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
ranges
::
range_value_t
<
RefRange
>>
&&
std
::
is_same_v
<
ranges
::
range_value_t
<
Range
>
,
f4_t
>
),
bool
>
check_err
(
const
Range
&
out
,
const
RefRange
&
ref
,
const
std
::
string
&
msg
=
"Error: Incorrect results!"
,
double
rtol
=
0.5
,
double
atol
=
0.5
)
{
if
(
out
.
size
()
!=
ref
.
size
())
{
std
::
cerr
<<
msg
<<
" out.size() != ref.size(), :"
<<
out
.
size
()
<<
" != "
<<
ref
.
size
()
<<
std
::
endl
;
return
false
;
}
bool
res
{
true
};
int
err_count
=
0
;
double
err
=
0
;
double
max_err
=
std
::
numeric_limits
<
float
>::
min
();
for
(
std
::
size_t
i
=
0
;
i
<
ref
.
size
();
++
i
)
{
const
double
o
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
out
),
i
));
const
double
r
=
type_convert
<
float
>
(
*
std
::
next
(
std
::
begin
(
ref
),
i
));
err
=
std
::
abs
(
o
-
r
);
if
(
err
>
atol
+
rtol
*
std
::
abs
(
r
)
||
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
))
{
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
if
(
err_count
<
5
)
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
}
res
=
false
;
}
}
if
(
!
res
)
{
std
::
cerr
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
"max err: "
<<
max_err
<<
" number of errors: "
<<
err_count
<<
std
::
endl
;
}
return
res
;
}
}
// namespace utils
}
// namespace utils
}
// namespace ck
}
// namespace ck
include/ck/library/utility/host_tensor_generator.hpp
View file @
9ba504b6
...
@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t>
...
@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t>
};
};
#endif
#endif
template
<
>
struct
GeneratorTensor_1
<
ck
::
f4_t
>
{
float
value
=
1.0
;
template
<
typename
...
Is
>
ck
::
f4_t
operator
()(
Is
...)
{
return
ck
::
type_convert
<
ck
::
f4_t
>
(
value
);
}
};
template
<
>
template
<
>
struct
GeneratorTensor_1
<
int8_t
>
struct
GeneratorTensor_1
<
int8_t
>
{
{
...
@@ -183,6 +195,20 @@ struct GeneratorTensor_2<ck::bf8_t>
...
@@ -183,6 +195,20 @@ struct GeneratorTensor_2<ck::bf8_t>
};
};
#endif
#endif
template
<
>
struct
GeneratorTensor_2
<
ck
::
f4_t
>
{
int
min_value
=
0
;
int
max_value
=
1
;
template
<
typename
...
Is
>
ck
::
f4_t
operator
()(
Is
...)
{
float
tmp
=
(
std
::
rand
()
%
(
max_value
-
min_value
))
+
min_value
;
return
ck
::
type_convert
<
ck
::
f4_t
>
(
tmp
);
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
GeneratorTensor_3
struct
GeneratorTensor_3
{
{
...
@@ -253,6 +279,23 @@ struct GeneratorTensor_3<ck::bf8_t>
...
@@ -253,6 +279,23 @@ struct GeneratorTensor_3<ck::bf8_t>
};
};
#endif
#endif
template
<
>
struct
GeneratorTensor_3
<
ck
::
f4_t
>
{
float
min_value
=
0
;
float
max_value
=
1
;
template
<
typename
...
Is
>
ck
::
f4_t
operator
()(
Is
...)
{
float
tmp
=
float
(
std
::
rand
())
/
float
(
RAND_MAX
);
float
fp32_tmp
=
min_value
+
tmp
*
(
max_value
-
min_value
);
return
ck
::
type_convert
<
ck
::
f4_t
>
(
fp32_tmp
);
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
GeneratorTensor_4
struct
GeneratorTensor_4
{
{
...
...
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
View file @
9ba504b6
...
@@ -94,8 +94,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
...
@@ -94,8 +94,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
const
Block2ETileMap
block_2_ctile_map
,
const
Block2ETileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
// offset base pointer for each work-group
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
View file @
9ba504b6
...
@@ -56,8 +56,7 @@ __global__ void
...
@@ -56,8 +56,7 @@ __global__ void
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2ETileMap
block_2_etile_map
)
const
Block2ETileMap
block_2_etile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
View file @
9ba504b6
...
@@ -74,8 +74,7 @@ __global__ void
...
@@ -74,8 +74,7 @@ __global__ void
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2ETileMap
block_2_etile_map
)
const
Block2ETileMap
block_2_etile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
View file @
9ba504b6
...
@@ -60,8 +60,7 @@ __global__ void
...
@@ -60,8 +60,7 @@ __global__ void
const
index_t
batch_count
,
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -108,7 +107,7 @@ __global__ void
...
@@ -108,7 +107,7 @@ __global__ void
ignore
=
block_2_ctile_map
;
ignore
=
block_2_ctile_map
;
ignore
=
batch_count
;
ignore
=
batch_count
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
compute_base_ptr_of_batch
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
}
// Computes C = A * B0 * B1
// Computes C = A * B0 * B1
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multi_d_xdl.hpp
View file @
9ba504b6
...
@@ -83,8 +83,7 @@ __global__ void
...
@@ -83,8 +83,7 @@ __global__ void
const
Block2ETileMap
block_2_etile_map
)
const
Block2ETileMap
block_2_etile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_gemm_multiple_d_xdl_cshuffle.hpp
View file @
9ba504b6
...
@@ -68,8 +68,7 @@ __global__ void
...
@@ -68,8 +68,7 @@ __global__ void
const
index_t
batch_count
,
const
index_t
batch_count
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_reduce_xdl_cshuffle.hpp
View file @
9ba504b6
...
@@ -59,8 +59,7 @@ __global__ void
...
@@ -59,8 +59,7 @@ __global__ void
const
ComputeBasePrtOfBatch
compute_base_ptr_of_batch_
,
const
ComputeBasePrtOfBatch
compute_base_ptr_of_batch_
,
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp
View file @
9ba504b6
...
@@ -67,8 +67,7 @@ __global__ void
...
@@ -67,8 +67,7 @@ __global__ void
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
)
const
C0MatrixMask
c0_matrix_mask
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -127,7 +126,7 @@ __global__ void
...
@@ -127,7 +126,7 @@ __global__ void
ignore
=
batch_count
;
ignore
=
batch_count
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
ignore
=
c0_matrix_mask
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
}
// Computes C = A * B0 * B1
// Computes C = A * B0 * B1
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_xdl_cshuffle.hpp
View file @
9ba504b6
...
@@ -62,8 +62,7 @@ __global__ void
...
@@ -62,8 +62,7 @@ __global__ void
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
ComputeBasePtrOfStridedBatch
compute_base_ptr_of_batch
,
const
C0MatrixMask
c0_matrix_mask
)
const
C0MatrixMask
c0_matrix_mask
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
@@ -112,7 +111,7 @@ __global__ void
...
@@ -112,7 +111,7 @@ __global__ void
ignore
=
batch_count
;
ignore
=
batch_count
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
compute_base_ptr_of_batch
;
ignore
=
c0_matrix_mask
;
ignore
=
c0_matrix_mask
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
}
// Computes C = A * B0 * B1
// Computes C = A * B0 * B1
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
View file @
9ba504b6
...
@@ -52,8 +52,7 @@ __global__ void
...
@@ -52,8 +52,7 @@ __global__ void
#endif
#endif
kernel_batched_gemm_xdlops_v2r3
(
const
typename
DeviceOp
::
Argument
karg
)
kernel_batched_gemm_xdlops_v2r3
(
const
typename
DeviceOp
::
Argument
karg
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
karg
.
Batch
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
karg
.
Batch
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
View file @
9ba504b6
...
@@ -55,8 +55,7 @@ __global__ void
...
@@ -55,8 +55,7 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
const
Block2ETileMap
block_2_etile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_as_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_as_grid
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
9ba504b6
...
@@ -55,8 +55,7 @@ __global__ void
...
@@ -55,8 +55,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
const
index_t
num_blocks_per_batch
=
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
num_batches
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
num_batches
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
...
@@ -97,7 +96,7 @@ __global__ void
...
@@ -97,7 +96,7 @@ __global__ void
ignore
=
b_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
c_element_op
;
ignore
=
block_2_ctile_map
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
}
// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
// specialization for #D conv: in[n, di, hi, wi, c] * wei[k, z, y, x, c] = out[n, do, ho, wo, k]
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp
View file @
9ba504b6
...
@@ -50,9 +50,8 @@ __global__ void
...
@@ -50,9 +50,8 @@ __global__ void
const
CGridDesc_M0_M10_M11_N0_N10_N11
e_grid_desc_m0_m10_m11_n0_n10_n11
,
const
CGridDesc_M0_M10_M11_N0_N10_N11
e_grid_desc_m0_m10_m11_n0_n10_n11
,
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx9__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
defined(__gfx12__))
constexpr
index_t
shared_block_size
=
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
ABDataType
);
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
ABDataType
);
...
...
Prev
1
2
3
4
5
6
7
…
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment