gaoqiong / composable_kernel_ROCM · Commits · f23a2e2a

Commit f23a2e2a, authored Feb 11, 2025 by Jakub Piasecki

    resolved conflicts

Parents: f3eb5a18, c0adab48
Changes: 340

Showing 20 changed files with 584 additions and 100 deletions (+584 −100).
example/ck_tile/35_batched_transpose/batched_transpose_example.cpp (+261 −0)
example/ck_tile/35_batched_transpose/batched_transpose_example.hpp (+25 −0)
example/ck_tile/35_batched_transpose/script/smoke_test.sh (+11 −0)
example/ck_tile/CMakeLists.txt (+1 −0)
include/ck/README.md (+20 −16)
include/ck/ck.hpp (+25 −10)
include/ck/config.h.in (+4 −0)
include/ck/host_utility/device_prop.hpp (+4 −3)
include/ck/library/utility/check_err.hpp (+75 −24)
include/ck/library/utility/host_tensor_generator.hpp (+43 −0)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp (+2 −2)
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp (+5 −1)
include/ck/tensor_operation/gpu/device/device_base.hpp (+9 −4)
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp (+42 −0)
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp (+18 −4)
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp (+3 −1)
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp (+32 −28)
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp (+1 −2)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp (+1 −2)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp (+2 −3)
example/ck_tile/35_batched_transpose/batched_transpose_example.cpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "batched_transpose_example.hpp"
#if 0
template <typename T>
void dump_host_tensor_4d(const ck_tile::HostTensor<T>& x)
{
    auto len = x.get_lengths();
    assert(len.size() == 4);
    std::cout << "[";
    for(size_t i = 0; i < len[0]; i++)
    {
        std::cout << i << ": [";
        for(size_t j = 0; j < len[1]; j++)
        {
            std::cout << j << ": [";
            for(size_t k = 0; k < len[2]; k++)
            {
                std::cout << k << ": [";
                for(size_t v = 0; v < len[3]; v++)
                {
                    if constexpr(std::is_same_v<T, ck_tile::fp16_t>)
                    {
                        auto m =
                            ck_tile::type_convert<float>(x(std::vector<std::size_t>{i, j, k, v}));
                        std::cout << m;
                        if(v != len[3] - 1)
                            std::cout << ",";
                    }
                    else
                    {
                        std::cout << x(std::vector<std::size_t>{i, j, k, v}) << " ";
                    }
                }
                std::cout << "]" << std::endl;
            }
            std::cout << "]" << std::endl;
        }
        std::cout << std::endl;
    }
    std::cout << "--------------------" << std::endl;
}
#endif
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
    double rtol = 1e-3;
    double atol = 1e-3;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
    double rtol = 1e-2;
    double atol = 1e-2;
    return ck_tile::make_tuple(rtol, atol);
}

template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
    if(init_method == "ui" || init_method == "ni")
    {
        unsigned max_rounding_point_distance = 0;
        double atol                          = 2e-3;
        return ck_tile::make_tuple(max_rounding_point_distance, atol);
    }
    else
    {
        unsigned max_rounding_point_distance = 1;
        double atol                          = 0.0625;
        return ck_tile::make_tuple(max_rounding_point_distance, atol);
    }
}

auto create_args(int argc, char* argv[])
{
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("v", "1", "whether do CPU validation or not")
        .insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
        .insert("N", "2", "input batch size.")
        .insert("C", "16", "input channel size.")
        .insert("H", "1", "input height size.")
        .insert("W", "16", "input width size.")
        .insert("layout_in", "NCHW", "input tensor data layout - NCHW by default")
        .insert("layout_out", "NHWC", "output tensor data layout - NHWC by default")
        .insert("seed", "-1", "seed to be used, -1 means random every time")
        .insert("kname", "0", "set to 1 will print kernel name");

    bool result = arg_parser.parse(argc, argv);
    return std::make_tuple(result, arg_parser);
}

template <typename Type>
bool run_batched_transpose(ck_tile::ArgParser args)
{
    int validate           = args.get_int("v");
    std::string prec       = args.get_str("pr");
    int N                  = args.get_int("N");
    int C                  = args.get_int("C");
    int H                  = args.get_int("H");
    int W                  = args.get_int("W");
    std::string layout_in  = args.get_str("layout_in");
    std::string layout_out = args.get_str("layout_out");
    int seed               = args.get_int("seed");

    int dim_in[4], dim_out[4];
    int stride_dim_in[4], stride_dim_out[4];

    bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC";
    bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW";
    assert(nchw2nhwc != nhwc2nchw);
    (void)nhwc2nchw;

    dim_in[0] = N;
    dim_in[1] = nchw2nhwc ? C : H;
    dim_in[2] = nchw2nhwc ? H : W;
    dim_in[3] = nchw2nhwc ? W : C;

    dim_out[0] = N;
    dim_out[1] = nchw2nhwc ? H : C;
    dim_out[2] = nchw2nhwc ? W : H;
    dim_out[3] = nchw2nhwc ? C : W;

    stride_dim_in[0] = C * H * W;
    stride_dim_in[1] = nchw2nhwc ? H * W : C * W;
    stride_dim_in[2] = nchw2nhwc ? W : C;
    stride_dim_in[3] = 1;

    stride_dim_out[0] = C * H * W;
    stride_dim_out[1] = nchw2nhwc ? C * W : H * W;
    stride_dim_out[2] = nchw2nhwc ? C : W;
    stride_dim_out[3] = 1;

    if(seed < 0)
    {
        seed = std::time(nullptr);
    }

    ck_tile::HostTensor<Type> x_host(
        {dim_in[0], dim_in[1], dim_in[2], dim_in[3]},
        {stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]});
    ck_tile::HostTensor<Type> y_host(
        {dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
        {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});

    ck_tile::FillUniformDistribution<Type>{-.5f, .5f}(x_host);

    ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes());

    x_dev.ToDevice(x_host.data());

    auto trait = batched_transpose_trait{prec, layout_in};

    uint32_t height = nchw2nhwc ? C : H * W;
    uint32_t width  = nchw2nhwc ? H * W : C;

    batched_transpose_kargs karg = [&]() {
        batched_transpose_kargs a_;
        a_.p_input  = x_dev.GetDeviceBuffer();
        a_.p_output = y_dev.GetDeviceBuffer();
        a_.batch    = N;
        a_.height   = height;
        a_.width    = width;
        return a_;
    }();

    ck_tile::stream_config sc{nullptr, true};
    auto ms = batched_transpose(trait, karg, sc);

    std::size_t num_operations = N * C * H * (W - 1);
    std::size_t num_bytes      = N * C * H * W * sizeof(Type);

    float ave_time   = ms * 1E-3;
    float gb_per_sec = num_bytes / ms * 1.E-6;
    float tflops     = static_cast<float>(num_operations) / ms * 1.E-6;

    std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H
              << ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out
              << " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops, "
              << gb_per_sec << " GB/s" << std::endl;

    printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n",
           prec.c_str(), N, C, H, W, layout_in.c_str(), ms);
    if(ms < 0)
        printf("not supported\n");
    fflush(stdout);

    if(ms < 0)
    {
        return false;
    }

    y_dev.FromDevice(y_host.data());

    bool rtn = true;
    if(validate)
    {
        // this host buffer will not be copied to the GPU, so no need to use strides
        ck_tile::HostTensor<Type> y_ref(
            {dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
            {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
        ck_tile::reference_batched_transpose<Type>(x_host, y_ref, layout_in, layout_out);

        auto [rtol, atol] = get_elimit<Type>("");
        rtn &= ck_tile::check_err(
            y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol);
    }

    printf("valid:%s\n", rtn ? "y" : "n");
    fflush(stdout);
    return rtn;
}

int main(int argc, char** argv)
{
    auto [result, args] = create_args(argc, argv);
    if(!result)
        return -1;

    std::string prec = args.get_str("pr");

    bool r = true;
    if(prec.compare("fp32") == 0)
    {
        r &= run_batched_transpose<float>(args);
    }
    else if(prec.compare("fp16") == 0)
    {
        r &= run_batched_transpose<ck_tile::fp16_t>(args);
    }
    else if(prec.compare("bf16") == 0)
    {
        r &= run_batched_transpose<ck_tile::bf16_t>(args);
    }
    else if(prec.compare("int8") == 0)
    {
        r &= run_batched_transpose<ck_tile::int8_t>(args);
    }
    return r ? 0 : -1;
}
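Note: the validation path above calls ck_tile::reference_batched_transpose, which is not part of this diff. For orientation, a minimal host-side sketch of the NCHW -> NHWC index mapping such a reference has to perform; plain std::vector buffers are assumed here instead of ck_tile::HostTensor, and this is an illustration, not the library routine:

#include <cstddef>
#include <vector>

template <typename T>
void reference_nchw_to_nhwc(const std::vector<T>& x, std::vector<T>& y,
                            std::size_t N, std::size_t C, std::size_t H, std::size_t W)
{
    for(std::size_t n = 0; n < N; ++n)
        for(std::size_t c = 0; c < C; ++c)
            for(std::size_t h = 0; h < H; ++h)
                for(std::size_t w = 0; w < W; ++w)
                {
                    // source index in NCHW layout, destination index in NHWC layout
                    const std::size_t src = ((n * C + c) * H + h) * W + w;
                    const std::size_t dst = ((n * H + h) * W + w) * C + c;
                    y[dst] = x[src];
                }
}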
example/ck_tile/35_batched_transpose/batched_transpose_example.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/ops/batched_transpose.hpp"
#include <vector>
#include <string>
#pragma once
struct batched_transpose_trait
{
    std::string type;
    std::string layout;
};

struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs
{
};

float batched_transpose(batched_transpose_trait t,
                        batched_transpose_kargs a,
                        ck_tile::stream_config s);
example/ck_tile/35_batched_transpose/script/smoke_test.sh (new file, mode 100755)
#!/bin/sh
EXE=./build/bin/tile_example_batched_transpose

for pr in "fp32" "fp16" "int8" ; do
$EXE -pr=$pr -N=1 -C=32 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=2 -C=12 -H=1 -W=32 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=3 -C=1334 -H=1 -W=37 -layout_in='NHWC' -layout_out='NCHW'
$EXE -pr=$pr -N=4 -C=27 -H=1 -W=32 -layout_in='NCHW' -layout_out='NHWC'
$EXE -pr=$pr -N=5 -C=1234 -H=1 -W=12 -layout_in='NCHW' -layout_out='NHWC'
done
example/ck_tile/CMakeLists.txt
...
@@ -17,3 +17,4 @@ add_subdirectory(14_moe_smoothquant)
add_subdirectory(15_fused_moe)
add_subdirectory(16_batched_gemm)
add_subdirectory(17_grouped_gemm)
add_subdirectory(35_batched_transpose)
include/ck/README.md
[Back to the main page](../../README.md)

# Composable Kernel supported operations

## Supported device operations

* [Average pooling]()
* [Batched contraction]()
* [Batched gemm]()
* [Batchnorm]()
* [CGEMM]()
* [Contraction]()
* [Convolution]()
* [Image to Column and Column to Image]()
* [Elementwise]()
* [GEMM]()
* [Max pooling]()
* [Reduce]()
* [Normalization]()
* [Permute]()
* [Put]()
* [Softmax]()
<!-- * [Average pooling](../../docs/markdown/tensor_operation/average_pooling.md) -->
<!-- * [Batched contraction](../../docs/markdown/tensor_operation/batched_contraction.md) -->
<!-- * [Batched gemm](../../docs/markdown/tensor_operation/batched_gemm.md) -->
<!-- * [Batchnorm](../../docs/markdown/tensor_operation/batchnorm.md) -->
<!-- * [CGEMM](../../docs/markdown/tensor_operation/cgemm.md) -->
<!-- * [Contraction](../../docs/markdown/tensor_operation/contraction.md) -->
<!-- * [Convolution](../../docs/markdown/tensor_operation/convolution.md) -->
<!-- * [Elementwise](../../docs/markdown/tensor_operation/elementwise.md) -->
* [GEMM](../../client_example/01_gemm/README.md)
* [Grouped Convolution Forward](../../client_example/07_grouped_convnd_fwd/README.md)
* [Grouped Convolution Backward Data](../../client_example/10_grouped_convnd_bwd_data/README.md)
* [Grouped Convolution Backward Weight](../../client_example/11_grouped_conv_bwd_weight/README.md)
<!-- * [Grouped GEMM](../../docs/markdown/tensor_operation/grouped_gemm.md) -->
<!-- * [Image to Column and Column to Image](../../docs/markdown/tensor_operation/img2col.md) -->
<!-- * [Max pooling](../../docs/markdown/tensor_operation/max_pooling.md) -->
<!-- * [Reduce](../../docs/markdown/tensor_operation/reduce.md) -->
<!-- * [Normalization](../../docs/markdown/tensor_operation/normalization.md) -->
<!-- * [Permute](../../docs/markdown/tensor_operation/permute.md) -->
<!-- * [Put](../../docs/markdown/tensor_operation/put.md) -->
<!-- * [Softmax](../../docs/markdown/tensor_operation/softmax.md) -->
include/ck/ck.hpp
...
@@ -5,7 +5,7 @@
#include "ck/config.h"
#include "ck/utility/env.hpp"
#ifndef CK_CODE_GEN_RTC
#ifndef CK_DONT_USE_HIP_RUNTIME_HEADERS
#include "hip/hip_runtime.h"
#include "hip/hip_fp16.h"
...
@@ -14,7 +14,7 @@
// environment variable to enable logging:
// export CK_LOGGING=ON or CK_LOGGING=1 or CK_LOGGING=ENABLED
CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
#endif
// to do: add various levels of logging with CK_LOG_LEVEL
#ifndef CK_TIME_KERNEL
...
@@ -55,10 +55,10 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// define general macros for various architectures
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
    defined(__gfx942__)
    defined(__gfx942__) || defined(__gfx950__)
#define __gfx9__
#endif
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
#define __gfx94__
#endif
#if defined(__gfx1010__) || defined(__gfx1011__) || defined(__gfx1012__)
...
@@ -163,6 +163,16 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// set rounding to nearest even as default for f8 conversions
#define CK_USE_SR_F8_CONVERSION 0
// set rounding to nearest even as default for f6 conversions
#define CK_USE_SR_F6_CONVERSION 0
// set rounding to nearest even as default for f4 conversions
#define CK_USE_SR_F4_CONVERSION 0
// shuffle pk_i4 values during conversion to optimize number of binary
// operations
#define CK_USE_PK4_LAYOUT_SHUFFLE 1
// block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
#define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...
@@ -235,13 +245,18 @@ CK_DECLARE_ENV_VAR_BOOL(CK_LOGGING)
// workaround: compiler issue on gfx908
#define CK_WORKAROUND_SWDEV_388832 1
// denorm test fix, required to work around dissue
#ifndef CK_WORKAROUND_DENORM_FIX
#define CK_WORKAROUND_DENORM_FIX 0
// denorm test fix, necessary for gfx90a
#ifndef CK_GFX90A_DENORM_WORKAROUND
#define CK_GFX90A_DENORM_WORKAROUND 0
#endif // CK_GFX90A_DENORM_WORKAROUND
// Enable only for gfx90a
#if defined(__gfx90a__)
#if CK_GFX90A_DENORM_WORKAROUND
#define CK_GFX90A_DENORM_WORKAROUND 1
#endif // CK_GFX90A_DENORM_WORKAROUND is set to 1
#else
// enable only for gfx90a
#define CK_WORKAROUND_DENORM_FIX = CK_WORKAROUND_DENORM_FIX && defined(__gfx90a__)
#endif // CK_WORKAROUND_DENORM_FIX
#define CK_GFX90A_DENORM_WORKAROUND 0
#endif // gfx90a
// set flag to 1 to build deprecated instances
#define CK_BUILD_DEPRECATED 1
...
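Note: with __gfx950__ folded into the __gfx9__ and __gfx94__ family macros above, the per-kernel architecture guards later in this commit collapse from an explicit device list to a single family check, for example:

// before: each kernel spelled out the supported targets
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx94__))
// after: the gfx9 family macro defined in ck.hpp covers them all, gfx950 included
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))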
include/ck/config.h.in
...
...
@@ -131,6 +131,10 @@
#cmakedefine CK_USE_FP8_ON_UNSUPPORTED_ARCH @CK_USE_FP8_ON_UNSUPPORTED_ARCH@
#endif
#ifndef CK_USE_NATIVE_MX_SUPPORT
#cmakedefine CK_USE_NATIVE_MX_SUPPORT @CK_USE_NATIVE_MX_SUPPORT@
#endif
// clang-format on
#endif // CK_CONFIG_H_IN
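Note: #cmakedefine expands to a #define only when the corresponding CMake variable is set, so downstream code can simply guard on the flag. A hypothetical consumer of the new option (not part of this commit) would look like:

#include "ck/config.h"

#ifdef CK_USE_NATIVE_MX_SUPPORT
// take the native MX path
#else
// take a fallback path
#endif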
include/ck/host_utility/device_prop.hpp
...
@@ -55,20 +55,21 @@ inline bool is_xdl_supported()
{
    return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
           ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
           ck::get_device_name() == "gfx942";
           ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
}

inline bool is_lds_direct_load_supported()
{
    // Check if direct loads from global memory to LDS are supported.
    return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
           ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942" ||
           ck::get_device_name() == "gfx950";
}

inline bool is_bf16_atomic_supported()
{
    return ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
           ck::get_device_name() == "gfx942";
           ck::get_device_name() == "gfx942" || ck::get_device_name() == "gfx950";
}

inline bool is_gfx101_supported()
...
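A hedged sketch of how these capability helpers are typically consumed from host code (standalone illustration; only functions shown in this hunk are used):

#include "ck/host_utility/device_prop.hpp"
#include <iostream>

void report_device_capabilities()
{
    // all three helpers below gain "gfx950" in this commit
    std::cout << "xdl supported:   " << ck::is_xdl_supported() << '\n'
              << "lds direct load: " << ck::is_lds_direct_load_supported() << '\n'
              << "bf16 atomics:    " << ck::is_bf16_atomic_supported() << '\n';
}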
include/ck/library/utility/check_err.hpp
...
@@ -26,6 +26,7 @@ namespace utils {
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_relative_threshold(const int number_of_accumulations = 1)
{
    using F4   = ck::f4_t;
    using F8   = ck::f8_t;
    using F16  = ck::half_t;
    using BF16 = ck::bhalf_t;
...
@@ -33,10 +34,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
    using I8  = int8_t;
    using I32 = int32_t;

    static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
                      is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
                      is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
                      is_same_v<ComputeDataType, int>,
    static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
                      is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
                      is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
                      is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
                  "Warning: Unhandled ComputeDataType for setting up the relative threshold!");

    double compute_error = 0;
    if constexpr(is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
...
@@ -49,10 +50,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
        compute_error = std::pow(2, -NumericUtils<ComputeDataType>::mant) * 0.5;
    }

    static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
                      is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
                      is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
                      is_same_v<OutDataType, int>,
    static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
                      is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
                      is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
                      is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
                  "Warning: Unhandled OutDataType for setting up the relative threshold!");

    double output_error = 0;
    if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
...
@@ -66,10 +67,10 @@ double get_relative_threshold(const int number_of_accumulations = 1)
    }
    double midway_error = std::max(compute_error, output_error);

    static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
                      is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
                      is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
                      is_same_v<AccDataType, int>,
    static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
                      is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
                      is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
                      is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
                  "Warning: Unhandled AccDataType for setting up the relative threshold!");

    double acc_error = 0;
    if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
...
@@ -87,6 +88,7 @@ double get_relative_threshold(const int number_of_accumulations = 1)
template <typename ComputeDataType, typename OutDataType, typename AccDataType = ComputeDataType>
double get_absolute_threshold(const double max_possible_num, const int number_of_accumulations = 1)
{
    using F4   = ck::f4_t;
    using F8   = ck::f8_t;
    using F16  = ck::half_t;
    using BF16 = ck::bhalf_t;
...
@@ -94,10 +96,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
    using I8  = int8_t;
    using I32 = int32_t;

    static_assert(is_same_v<ComputeDataType, F8> || is_same_v<ComputeDataType, F16> ||
                      is_same_v<ComputeDataType, BF16> || is_same_v<ComputeDataType, F32> ||
                      is_same_v<ComputeDataType, I8> || is_same_v<ComputeDataType, I32> ||
                      is_same_v<ComputeDataType, int>,
    static_assert(is_same_v<ComputeDataType, F4> || is_same_v<ComputeDataType, F8> ||
                      is_same_v<ComputeDataType, F16> || is_same_v<ComputeDataType, BF16> ||
                      is_same_v<ComputeDataType, F32> || is_same_v<ComputeDataType, I8> ||
                      is_same_v<ComputeDataType, I32> || is_same_v<ComputeDataType, int>,
                  "Warning: Unhandled ComputeDataType for setting up the absolute threshold!");

    auto expo            = std::log2(std::abs(max_possible_num));
    double compute_error = 0;
...
@@ -111,10 +113,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
        compute_error = std::pow(2, expo - NumericUtils<ComputeDataType>::mant) * 0.5;
    }

    static_assert(is_same_v<OutDataType, F8> || is_same_v<OutDataType, F16> ||
                      is_same_v<OutDataType, BF16> || is_same_v<OutDataType, F32> ||
                      is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
                      is_same_v<OutDataType, int>,
    static_assert(is_same_v<OutDataType, F4> || is_same_v<OutDataType, F8> ||
                      is_same_v<OutDataType, F16> || is_same_v<OutDataType, BF16> ||
                      is_same_v<OutDataType, F32> || is_same_v<OutDataType, I8> ||
                      is_same_v<OutDataType, I32> || is_same_v<OutDataType, int>,
                  "Warning: Unhandled OutDataType for setting up the absolute threshold!");

    double output_error = 0;
    if constexpr(is_same_v<OutDataType, I8> || is_same_v<OutDataType, I32> ||
...
@@ -128,10 +130,10 @@ double get_absolute_threshold(const double max_possible_num, const int number_of
    }
    double midway_error = std::max(compute_error, output_error);

    static_assert(is_same_v<AccDataType, F8> || is_same_v<AccDataType, F16> ||
                      is_same_v<AccDataType, BF16> || is_same_v<AccDataType, F32> ||
                      is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
                      is_same_v<AccDataType, int>,
    static_assert(is_same_v<AccDataType, F4> || is_same_v<AccDataType, F8> ||
                      is_same_v<AccDataType, F16> || is_same_v<AccDataType, BF16> ||
                      is_same_v<AccDataType, F32> || is_same_v<AccDataType, I8> ||
                      is_same_v<AccDataType, I32> || is_same_v<AccDataType, int>,
                  "Warning: Unhandled AccDataType for setting up the absolute threshold!");

    double acc_error = 0;
    if constexpr(is_same_v<AccDataType, I8> || is_same_v<AccDataType, I32> ||
...
@@ -450,5 +452,54 @@ check_err(const Range& out,
    return res;
}

template <typename Range, typename RefRange>
std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_value_t<RefRange>> &&
                  std::is_same_v<ranges::range_value_t<Range>, f4_t>),
                 bool>
check_err(const Range& out,
          const RefRange& ref,
          const std::string& msg = "Error: Incorrect results!",
          double rtol            = 0.5,
          double atol            = 0.5)
{
    if(out.size() != ref.size())
    {
        std::cerr << msg << " out.size() != ref.size(), :" << out.size() << " != " << ref.size()
                  << std::endl;
        return false;
    }

    bool res{true};
    int err_count  = 0;
    double err     = 0;
    double max_err = std::numeric_limits<float>::min();
    for(std::size_t i = 0; i < ref.size(); ++i)
    {
        const double o = type_convert<float>(*std::next(std::begin(out), i));
        const double r = type_convert<float>(*std::next(std::begin(ref), i));
        err            = std::abs(o - r);
        if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
        {
            max_err = err > max_err ? err : max_err;
            err_count++;
            if(err_count < 5)
            {
                std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i
                          << "] != ref[" << i << "]: " << o << " != " << r << std::endl;
            }
            res = false;
        }
    }
    if(!res)
    {
        std::cerr << std::setw(12) << std::setprecision(7) << "max err: " << max_err
                  << " number of errors: " << err_count << std::endl;
    }
    return res;
}

} // namespace utils
} // namespace ck
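To make the threshold logic concrete: for an F16 compute type with a 10-bit mantissa, get_relative_threshold evaluates std::pow(2, -mant) * 0.5 ≈ 4.88e-4, and the element-wise check above fails when |out - ref| > atol + rtol * |ref|. A standalone sketch of that comparison (the mantissa width is assumed here rather than taken from NumericUtils):

#include <cmath>
#include <cstdio>

int main()
{
    const int mant    = 10;                          // fp16 mantissa bits (assumed)
    const double rtol = std::pow(2, -mant) * 0.5;    // ~4.88e-4, mirroring get_relative_threshold
    const double atol = 1e-3;

    const double ref = 1.25, out = 1.2504;
    const bool fail  = std::abs(out - ref) > atol + rtol * std::abs(ref);
    std::printf("rtol=%g fail=%d\n", rtol, fail);    // prints rtol=0.000488281 fail=0
}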
include/ck/library/utility/host_tensor_generator.hpp
...
@@ -69,6 +69,18 @@ struct GeneratorTensor_1<ck::f8_t>
};
#endif

template <>
struct GeneratorTensor_1<ck::f4_t>
{
    float value = 1.0;

    template <typename... Is>
    ck::f4_t operator()(Is...)
    {
        return ck::type_convert<ck::f4_t>(value);
    }
};

template <>
struct GeneratorTensor_1<int8_t>
{
...
@@ -183,6 +195,20 @@ struct GeneratorTensor_2<ck::bf8_t>
};
#endif

template <>
struct GeneratorTensor_2<ck::f4_t>
{
    int min_value = 0;
    int max_value = 1;

    template <typename... Is>
    ck::f4_t operator()(Is...)
    {
        float tmp = (std::rand() % (max_value - min_value)) + min_value;
        return ck::type_convert<ck::f4_t>(tmp);
    }
};

template <typename T>
struct GeneratorTensor_3
{
...
@@ -253,6 +279,23 @@ struct GeneratorTensor_3<ck::bf8_t>
};
#endif

template <>
struct GeneratorTensor_3<ck::f4_t>
{
    float min_value = 0;
    float max_value = 1;

    template <typename... Is>
    ck::f4_t operator()(Is...)
    {
        float tmp = float(std::rand()) / float(RAND_MAX);

        float fp32_tmp = min_value + tmp * (max_value - min_value);

        return ck::type_convert<ck::f4_t>(fp32_tmp);
    }
};

template <typename T>
struct GeneratorTensor_4
{
...
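A minimal sketch of applying one of the new f4_t generators to fill a host buffer; a plain loop over std::vector is assumed here rather than the HostTensor fill machinery, and the CK headers defining ck::f4_t, ck::type_convert and the generator structs are assumed to be included:

#include <vector>

void fill_f4(std::vector<ck::f4_t>& buf)
{
    GeneratorTensor_3<ck::f4_t> gen{0.f, 1.f}; // min_value, max_value as in the specialization above
    for(auto& v : buf)
        v = gen();                             // operator()(Is...) ignores any index arguments
}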
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -131,7 +131,7 @@ struct ThreadGroupTensorSliceTransfer_v7r2
    }

    template <typename T>
    using is_tuple = decltype(std::declval<T&>().IsTuple());
    using is_tuple = decltype(ck::declval<T&>().IsTuple());

    template <typename DstBuffers, index_t ThreadScratchId = 0>
    __device__ void RunWrite(const DstDescs& dst_descs,
...
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#ifndef CK_CODE_GEN_RTC
#include <string>
#endif

namespace ck {
namespace tensor_operation {
...
@@ -18,6 +20,7 @@ enum struct ConvolutionForwardSpecialization
    Filter3x3,
};

#ifndef CK_CODE_GEN_RTC
inline std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization& s)
{
    switch(s)
...
@@ -30,6 +33,7 @@ inline std::string getConvForwardSpecializationString(const ConvolutionForwardSp
    default: return "Unrecognized specialization!";
    }
}
#endif

} // namespace device
} // namespace tensor_operation
...
include/ck/tensor_operation/gpu/device/device_base.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#ifndef CK_CODE_GEN_RTC
#include <string>
#include <sstream>
#include <regex>
#include <optional>

#include "ck/stream_config.hpp"
#endif

namespace ck {
namespace tensor_operation {
namespace device {

#ifndef CK_CODE_GEN_RTC
#define GET_OBJECT_NAME_IMLP \
    std::optional<std::string> GetObjectName() const override \
    { \
...
@@ -41,7 +43,9 @@ namespace device {
    }

#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMLP GET_TEMPLATE_INFO_IMPL
#endif

#ifndef CK_CODE_GEN_RTC
struct BaseArgument
{
    BaseArgument() = default;
...
@@ -66,13 +70,14 @@ struct BaseInvoker
    virtual ~BaseInvoker() {}
};
#endif

struct BaseOperator
{
    BaseOperator()                    = default;
    BaseOperator(const BaseOperator&) = default;
    BaseOperator& operator=(const BaseOperator&) = default;

#ifndef CK_CODE_GEN_RTC
    virtual bool IsSupportedArgument(const BaseArgument*) { return false; }

    virtual std::string GetTypeString() const { return ""; }
...
@@ -100,7 +105,7 @@ struct BaseOperator
        assert(p_arg);
        p_arg->p_workspace_ = p_workspace;
    }
#endif

    virtual ~BaseOperator() {}
};
...
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
...
@@ -44,6 +44,48 @@ struct DeviceBatchedGemm : public BaseOperator
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename BScaleType,
          typename CDataType,
          index_t ScaleBlockN,
          index_t ScaleBlockK,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceBatchedGemmV2BScale : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        ck::index_t StrideScaleB,
                        ck::index_t BatchStrideA,
                        ck::index_t BatchStrideB,
                        ck::index_t BatchStrideC,
                        ck::index_t BatchStrideScaleB,
                        const void* p_b_scale,
                        ck::index_t Batch,
                        ck::index_t KBatch,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;

    virtual bool GetPermuteB() = 0;

    virtual ck::index_t GetKPerBlock() = 0;
};

template <typename ALayout,
          typename BLayout,
          typename CLayout,
...
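A hedged sketch of driving the new DeviceBatchedGemmV2BScale interface from host code; the concrete instance type, element-wise ops, and stride values are placeholders, and IsSupportedArgument / the invoker's Run come from the base classes in device_base.hpp:

template <typename Op, typename ElementOp>
float run_bscale_gemm(Op& op,
                      const void* p_a, const void* p_b, const void* p_b_scale, void* p_c,
                      ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t batch)
{
    // argument order follows MakeArgumentPointer() declared above; strides are illustrative only
    auto arg = op.MakeArgumentPointer(p_a, p_b, p_c,
                                      M, N, K,
                                      /*StrideA*/ K, /*StrideB*/ K, /*StrideC*/ N,
                                      /*StrideScaleB*/ 0,
                                      /*BatchStrideA*/ M * K, /*BatchStrideB*/ N * K,
                                      /*BatchStrideC*/ M * N, /*BatchStrideScaleB*/ 0,
                                      p_b_scale, batch, /*KBatch*/ 1,
                                      ElementOp{}, ElementOp{}, ElementOp{});
    if(!op.IsSupportedArgument(arg.get()))
        return -1.f; // instance cannot handle this problem description
    return op.MakeInvokerPointer()->Run(arg.get(), StreamConfig{});
}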
include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#ifndef CK_CODE_GEN_RTC
#include <array>
#endif

#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
...
@@ -13,8 +15,13 @@ namespace ck {
namespace tensor_operation {
namespace device {

#ifdef CK_CODE_GEN_RTC
template <typename T>
using is_tuple = decltype(ck::declval<T&>().IsTuple());
#else
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
#endif

/**
 * \brief Grouped Convolution Forward
...
@@ -72,12 +79,18 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
    static constexpr index_t NumDTensor = DsDataType::Size();

    static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");

#ifdef CK_CODE_GEN_RTC
    using APointers =
        ck::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
    using BPointers =
        ck::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
#else
    // If DataType is tuple, user has to pass std::array with pointers.
    using APointers =
        std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
        ck::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
    using BPointers =
        std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
        ck::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
#endif

#ifndef CK_CODE_GEN_RTC
    /**
     * \brief Make argument pointer for grouped conv fwd.
...
@@ -150,6 +163,7 @@ struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
        const CDEElementwiseOperation& cde_element_op) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
#endif
};

} // namespace device
...
include/ck/tensor_operation/gpu/device/gemm_specialization.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -29,6 +29,7 @@ enum struct GemmSpecialization
    MNKOPadding,
};

#ifndef CK_CODE_GEN_RTC
inline std::string getGemmSpecializationString(const GemmSpecialization& s)
{
    switch(s)
...
@@ -52,6 +53,7 @@ inline std::string getGemmSpecializationString(const GemmSpecialization& s)
    default: return "Unrecognized specialization!";
    }
}
#endif

} // namespace device
} // namespace tensor_operation
...
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...
@@ -3,11 +3,17 @@
#pragma once

#ifndef CK_CODE_GEN_RTC
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>
#include <stdio.h>

#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#endif

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
...
@@ -15,15 +21,12 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"

namespace ck {
namespace tensor_operation {
...
@@ -91,8 +94,7 @@ __device__ void device_grouped_conv_fwd_multiple_abd_xdl_cshuffle(
    const Block2ETileMap block_2_ctile_map,
    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // offset base pointer for each work-group
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...
@@ -259,8 +261,13 @@ __global__ void
} // namespace

#ifdef CK_CODE_GEN_RTC
template <typename T>
using is_tuple = decltype(ck::declval<T&>().IsTuple());
#else
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());
#endif

//
// @brief Device Convolution operation.
...
@@ -429,8 +436,8 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    // If we are using multiAB and one of the template datatype parameters is not a tuple, convert
    // it to it
    using GemmADataType = std::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
    using GemmBDataType = std::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;
    using GemmADataType = ck::conditional_t<!isMultiA && isMultiB, Tuple<ADataType>, ADataType>;
    using GemmBDataType = ck::conditional_t<!isMultiB && isMultiA, Tuple<BDataType>, BDataType>;

#define GridwiseGemmTemplateParameters \
    GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
...
@@ -449,15 +456,13 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
    CDEBlockTransferScalarPerVector_NPerBlock, LoopSched

    // Use appropriate gridwise gemm
    using GridwiseGemm =
        std::conditional_t<isMultiA || isMultiB,
                           GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmTemplateParameters>,
                           GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;
        ck::conditional_t<isMultiA || isMultiB,
                          GridwiseGemmMultipleABD_xdl_cshuffle<GridwiseGemmTemplateParameters>,
                          GridwiseGemmMultipleD_xdl_cshuffle<GridwiseGemmTemplateParameters>>;

    // If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers.
    using APointers =
        std::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
    using BPointers =
        std::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;
    using APointers =
        ck::conditional_t<isMultiA, ck::Array<const void*, NumATensor>&, const void*>;
    using BPointers =
        ck::conditional_t<isMultiB, ck::Array<const void*, NumBTensor>&, const void*>;

    // Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
    // in initializer list what is required for single const pointer).
    using AGridPointer = remove_cvref_t<
...
@@ -812,7 +817,6 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
            // FIXME: layout
            if constexpr(is_same_v<DLayout, ctc::G_NW_K> || is_same_v<DLayout, ctc::G_NHW_K> ||
                         is_same_v<DLayout, ctc::G_NDHW_K> || is_same_v<DLayout, ctc::GNWK> ||
...
@@ -965,18 +969,18 @@ struct CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
        const BElementwiseOperation& b_element_op,
        const CDEElementwiseOperation& cde_element_op)
    {
        std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
        std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
        std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
        std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
        std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
        std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
        std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
        std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
        std::array<index_t, NDimSpatial> conv_filter_strides_i32;
        std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
        std::array<index_t, NDimSpatial> input_left_pads_i32;
        std::array<index_t, NDimSpatial> input_right_pads_i32;
        ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
        ck::Array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
        ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
        ck::Array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
        ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
        ck::Array<ck::Array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
        ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
        ck::Array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
        ck::Array<index_t, NDimSpatial> conv_filter_strides_i32;
        ck::Array<index_t, NDimSpatial> conv_filter_dilations_i32;
        ck::Array<index_t, NDimSpatial> input_left_pads_i32;
        ck::Array<index_t, NDimSpatial> input_right_pads_i32;

        array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
        array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
...
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp
...
@@ -56,8 +56,7 @@ __global__ void
    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
    const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t num_blocks_per_batch =
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp
...
@@ -74,8 +74,7 @@ __global__ void
    const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
    const Block2ETileMap block_2_etile_map)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_gemm_xdl_cshuffle.hpp
...
@@ -60,8 +60,7 @@ __global__ void
    const index_t batch_count,
    const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    defined(__gfx94__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
...
@@ -108,7 +107,7 @@ __global__ void
    ignore = block_2_ctile_map;
    ignore = batch_count;
    ignore = compute_base_ptr_of_batch;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx9__))
}

// Computes C = A * B0 * B1
...