Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
df35f46d
Commit
df35f46d
authored
Oct 07, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
9c0811f3
7733ae16
Changes
69
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
378 additions
and
75 deletions
+378
-75
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
+3
-1
example/ck_tile/04_img2col/CMakeLists.txt
example/ck_tile/04_img2col/CMakeLists.txt
+3
-0
example/ck_tile/04_img2col/README.md
example/ck_tile/04_img2col/README.md
+12
-0
example/ck_tile/04_img2col/image_to_column.cpp
example/ck_tile/04_img2col/image_to_column.cpp
+170
-0
example/ck_tile/04_img2col/image_to_column.hpp
example/ck_tile/04_img2col/image_to_column.hpp
+105
-0
example/ck_tile/CMakeLists.txt
example/ck_tile/CMakeLists.txt
+1
-0
include/ck/config.h.in
include/ck/config.h.in
+0
-7
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
...or_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
+5
-4
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
+10
-8
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
+5
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
...evice/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
...grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
...device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+1
-1
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
+9
-9
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+6
-6
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+29
-29
include/ck_tile/core/container/array.hpp
include/ck_tile/core/container/array.hpp
+12
-1
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+2
-0
No files found.
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
View file @
df35f46d
...
@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
...
@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
YDataType
,
YDataType
,
MeanDataType
,
MeanDataType
,
InvStdDataType
,
InvStdDataType
,
Shape
>
;
Shape
,
true
,
true
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
PipelineProblem
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
PipelineProblem
>
;
...
...
example/ck_tile/04_img2col/CMakeLists.txt
0 → 100644
View file @
df35f46d
# not using add_example_executable() to add this target, since we don't want this to have
# to be included in "make all/install/check"
add_executable
(
tile_example_img2col EXCLUDE_FROM_ALL image_to_column.cpp
)
example/ck_tile/04_img2col/README.md
0 → 100644
View file @
df35f46d
# Image to Column
This folder contains example for Image to Column using ck_tile tile-programming implementation.
## build
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
make tile_example_img2col -j
```
This will result in an executable
`build/bin/tile_example_img2col`
example/ck_tile/04_img2col/image_to_column.cpp
0 → 100644
View file @
df35f46d
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstring>
#include "ck_tile/host.hpp"
#include "image_to_column.hpp"
// Host API implementation
template
<
>
float
image_to_column
(
const
image_to_column_traits
&
traits
,
const
image_to_column_args
<
2
>&
args
,
const
ck_tile
::
stream_config
&
stream_conf
)
{
if
(
traits
.
data_type
.
compare
(
"fp16"
)
==
0
)
{
constexpr
ck_tile
::
index_t
NDimSpatial
=
2
;
constexpr
ck_tile
::
index_t
VectorSize
=
8
;
using
thread_tile
=
ck_tile
::
sequence
<
8
,
8
>
;
using
warp_tile
=
ck_tile
::
sequence
<
64
,
64
>
;
using
block_tile
=
ck_tile
::
sequence
<
128
,
128
>
;
using
Shape
=
ck_tile
::
TileImageToColumnShape
<
thread_tile
,
warp_tile
,
block_tile
>
;
using
InDataType
=
ck_tile
::
half_t
;
using
OutDataType
=
ck_tile
::
half_t
;
using
PipelineProblem
=
ck_tile
::
BlockImageToColumnProblem
<
InDataType
,
OutDataType
,
Shape
,
NDimSpatial
,
VectorSize
,
VectorSize
>
;
using
Kernel
=
ck_tile
::
ImageToColumn
<
PipelineProblem
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
args
.
p_in
,
args
.
p_out
,
args
.
G
,
args
.
N
,
args
.
C
,
args
.
input_spatial_lengths
,
args
.
filter_spatial_lengths
,
args
.
output_spatial_lengths
,
args
.
image_g_n_c_wis_strides
,
args
.
gemm_g_m_k_strides
,
args
.
conv_filter_strides
,
args
.
conv_filter_dilations
,
args
.
input_left_pads
,
args
.
input_right_pads
);
const
dim3
grids
=
Kernel
::
GridSize
(
args
.
N
*
args
.
output_spatial_lengths
[
0
]
*
args
.
output_spatial_lengths
[
1
],
args
.
filter_spatial_lengths
[
0
]
*
args
.
filter_spatial_lengths
[
1
]
*
args
.
C
,
args
.
G
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
ck_tile
::
index_t
kBlockPerCu
=
2
;
float
ave_time
=
ck_tile
::
launch_kernel
(
stream_conf
,
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
return
ave_time
;
}
return
0
;
}
int
main
(
int
argc
,
char
*
argv
[])
{
constexpr
ck_tile
::
index_t
NDimSpatial
=
2
;
ExecutionConfig
config
;
ck_tile
::
conv
::
ConvParam
conv_params
=
DefaultConvParams
;
if
(
!
parse_cmd_args
(
argc
,
argv
,
config
,
conv_params
))
{
return
EXIT_FAILURE
;
}
if
(
conv_params
.
num_dim_spatial_
!=
NDimSpatial
)
{
std
::
cerr
<<
"unsupported # of spatial dimensions"
<<
std
::
endl
;
return
EXIT_FAILURE
;
}
using
InDataType
=
ck_tile
::
half_t
;
using
OutDataType
=
ck_tile
::
half_t
;
using
ImLayout
=
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
;
const
auto
G
=
conv_params
.
G_
;
const
auto
N
=
conv_params
.
N_
;
const
auto
C
=
conv_params
.
C_
;
const
ck_tile
::
long_index_t
NHoWo
=
N
*
std
::
accumulate
(
conv_params
.
output_spatial_lengths_
.
begin
(),
std
::
next
(
conv_params
.
output_spatial_lengths_
.
begin
(),
NDimSpatial
),
1
,
std
::
multiplies
<>
());
const
ck_tile
::
long_index_t
CYX
=
C
*
std
::
accumulate
(
conv_params
.
filter_spatial_lengths_
.
begin
(),
std
::
next
(
conv_params
.
filter_spatial_lengths_
.
begin
(),
NDimSpatial
),
1
,
std
::
multiplies
<>
());
const
auto
in_desc
=
ck_tile
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
ImLayout
>
(
conv_params
);
const
auto
out_desc
=
ck_tile
::
HostTensorDescriptor
({
G
,
NHoWo
,
CYX
});
// host verify
ck_tile
::
HostTensor
<
InDataType
>
in
(
in_desc
);
ck_tile
::
HostTensor
<
OutDataType
>
out_device
(
out_desc
);
ck_tile
::
HostTensor
<
OutDataType
>
out_host
(
out_desc
);
switch
(
config
.
init_method
)
{
case
0
:
break
;
case
1
:
ck_tile
::
FillUniformDistributionIntegerValue
<
InDataType
>
{
-
5.
f
,
5.
f
}(
in
);
break
;
default:
ck_tile
::
FillUniformDistribution
<
InDataType
>
{
-
0.5
,
0.5
}(
in
);
break
;
}
ck_tile
::
DeviceMem
in_device_buf
(
in
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
out_device_buf
(
out_device
.
get_element_space_size_in_bytes
());
in_device_buf
.
ToDevice
(
in
.
data
());
image_to_column_traits
traits
{
"fp16"
};
image_to_column_args
<
NDimSpatial
>
args
{
in_device_buf
.
GetDeviceBuffer
(),
out_device_buf
.
GetDeviceBuffer
(),
G
,
N
,
C
,
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
input_spatial_lengths_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
filter_spatial_lengths_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
output_spatial_lengths_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
+
3
>
(
in_desc
.
get_strides
()),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
3
>
(
out_desc
.
get_strides
()),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
conv_filter_strides_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
conv_filter_dilations_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
input_left_pads_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
input_right_pads_
)};
float
ave_time
=
image_to_column
(
traits
,
args
,
ck_tile
::
stream_config
{
nullptr
,
config
.
time_kernel
});
std
::
size_t
num_btype
=
G
*
NHoWo
*
CYX
*
(
sizeof
(
OutDataType
)
+
sizeof
(
InDataType
));
float
gb_per_sec
=
num_btype
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s"
<<
std
::
endl
;
bool
pass
=
true
;
if
(
config
.
do_verification
)
{
// reference
ck_tile
::
reference_im2col
<
InDataType
,
OutDataType
,
NDimSpatial
>
(
in
,
out_host
,
conv_params
);
out_device_buf
.
FromDevice
(
out_device
.
data
());
pass
=
ck_tile
::
check_err
(
out_device
,
out_host
);
std
::
cout
<<
"valid:"
<<
(
pass
?
"y"
:
"n"
)
<<
std
::
endl
;
}
return
!
pass
;
}
example/ck_tile/04_img2col/image_to_column.hpp
0 → 100644
View file @
df35f46d
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/image_to_column.hpp"
#include <string>
#define DefaultConvParams \
ck_tile::conv::ConvParam \
{ \
2, 2, 32, 32, 32, {4, 4}, {64, 64}, {1, 1}, {1, 1}, {0, 0}, { 0, 0 } \
}
struct
ExecutionConfig
final
{
bool
do_verification
=
true
;
int
init_method
=
1
;
bool
time_kernel
=
false
;
};
inline
void
print_help_msg
()
{
std
::
cerr
<<
"arg1: verification (0=no, 1=yes)
\n
"
<<
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
<<
"arg3: time kernel (0=no, 1=yes)
\n
"
<<
ck_tile
::
conv
::
get_conv_param_parser_helper_msg
()
<<
std
::
endl
;
}
inline
bool
parse_cmd_args
(
int
argc
,
char
*
argv
[],
ExecutionConfig
&
config
,
ck_tile
::
conv
::
ConvParam
&
conv_params
)
{
constexpr
int
num_execution_config_args
=
3
;
// arguments for do_verification, init_method, time_kernel
constexpr
int
num_conv_param_leading_args
=
5
;
// arguments for num_dim_spatial_, G_, N_, K_, C_
constexpr
int
threshold_to_catch_partial_args
=
1
+
num_execution_config_args
;
constexpr
int
threshold_to_catch_all_args
=
threshold_to_catch_partial_args
+
num_conv_param_leading_args
;
if
(
argc
==
1
)
{
// use default
config
=
ExecutionConfig
{};
}
// catch only ExecutionConfig arguments
else
if
(
argc
==
threshold_to_catch_partial_args
)
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
}
// catch both ExecutionConfig & ConvParam arguments
else
if
(
threshold_to_catch_all_args
<
argc
&&
((
argc
-
threshold_to_catch_all_args
)
%
3
==
0
))
{
config
.
do_verification
=
std
::
stoi
(
argv
[
1
]);
config
.
init_method
=
std
::
stoi
(
argv
[
2
]);
config
.
time_kernel
=
std
::
stoi
(
argv
[
3
]);
const
ck_tile
::
index_t
num_dim_spatial
=
std
::
stoi
(
argv
[
4
]);
conv_params
=
ck_tile
::
conv
::
parse_conv_param
(
num_dim_spatial
,
threshold_to_catch_partial_args
,
argv
);
}
else
{
print_help_msg
();
return
false
;
}
return
true
;
}
struct
image_to_column_traits
{
std
::
string
data_type
;
};
template
<
ck_tile
::
index_t
NDimSpatial
>
struct
image_to_column_args
{
const
void
*
p_in
;
void
*
p_out
;
const
ck_tile
::
long_index_t
G
;
const
ck_tile
::
long_index_t
N
;
const
ck_tile
::
long_index_t
C
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
input_spatial_lengths
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
filter_spatial_lengths
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
output_spatial_lengths
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
+
3
>
image_g_n_c_wis_strides
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
3
>
gemm_g_m_k_strides
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
conv_filter_strides
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
conv_filter_dilations
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
input_left_pads
;
const
ck_tile
::
array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
input_right_pads
;
};
// host API
template
<
ck_tile
::
index_t
NDimSpatial
>
float
image_to_column
(
const
image_to_column_traits
&
,
const
image_to_column_args
<
NDimSpatial
>&
,
const
ck_tile
::
stream_config
&
);
example/ck_tile/CMakeLists.txt
View file @
df35f46d
...
@@ -5,3 +5,4 @@ include_directories(AFTER
...
@@ -5,3 +5,4 @@ include_directories(AFTER
add_subdirectory
(
01_fmha
)
add_subdirectory
(
01_fmha
)
add_subdirectory
(
02_layernorm2d
)
add_subdirectory
(
02_layernorm2d
)
add_subdirectory
(
03_gemm
)
add_subdirectory
(
03_gemm
)
add_subdirectory
(
04_img2col
)
include/ck/config.h.in
View file @
df35f46d
...
@@ -97,13 +97,6 @@
...
@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
#endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
//
//
// CK kernels which support XDL (MI series)
// CK kernels which support XDL (MI series)
//
//
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
View file @
df35f46d
...
@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
...
@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}
}
template
<
>
template
<
>
__device__
static
constexpr
auto
TailScheduler
<
1
>
()
__device__
constexpr
auto
TailScheduler
<
1
>
()
{
{
// schedule
// schedule
constexpr
auto
num_ds_read_inst
=
constexpr
auto
num_ds_read_inst
=
...
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
...
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}
}
template
<
>
template
<
>
__device__
static
constexpr
auto
TailScheduler
<
2
>
()
__device__
constexpr
auto
TailScheduler
<
2
>
()
{
{
// schedule
// schedule
constexpr
auto
num_ds_read_inst
=
constexpr
auto
num_ds_read_inst
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
View file @
df35f46d
...
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
View file @
df35f46d
...
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
...
@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
View file @
df35f46d
...
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
...
@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
xdlops_gemm
.
template
Run
<
>(
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
View file @
df35f46d
...
@@ -64,7 +64,7 @@ __global__ void
...
@@ -64,7 +64,7 @@ __global__ void
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
return
;
return
;
const
auto
StrideAs
=
gemm_desc_ptr
[
group_id
].
StrideAs
;
const
auto
StrideAs
=
gemm_desc_ptr
[
group_id
].
StrideAs
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
View file @
df35f46d
...
@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
...
@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const
index_t
N
=
gemm_descs
[
i
].
N_
;
const
index_t
N
=
gemm_descs
[
i
].
N_
;
const
index_t
K
=
gemm_descs
[
i
].
K_
;
const
index_t
K
=
gemm_descs
[
i
].
K_
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
{
{
skipped_group_count_
++
;
skipped_group_count_
++
;
continue
;
continue
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
View file @
df35f46d
...
@@ -109,7 +109,7 @@ __global__ void
...
@@ -109,7 +109,7 @@ __global__ void
N
=
gemm_desc_ptr
[
group_id
].
N
;
N
=
gemm_desc_ptr
[
group_id
].
N
;
K
=
gemm_desc_ptr
[
group_id
].
K
;
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
{
{
grid_size_grp
=
0
;
grid_size_grp
=
0
;
continue
;
continue
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
df35f46d
...
@@ -68,7 +68,7 @@ __global__ void
...
@@ -68,7 +68,7 @@ __global__ void
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
return
;
return
;
const
auto
StrideA
=
gemm_desc_ptr
[
group_id
].
StrideA
;
const
auto
StrideA
=
gemm_desc_ptr
[
group_id
].
StrideA
;
...
...
include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
View file @
df35f46d
...
@@ -324,55 +324,55 @@ struct DppSelector
...
@@ -324,55 +324,55 @@ struct DppSelector
static
constexpr
auto
GetDpp
();
static
constexpr
auto
GetDpp
();
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
8
,
32
>
()
{
{
return
DppInstr
::
dpp8_f16_8x32x2
;
return
DppInstr
::
dpp8_f16_8x32x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
8
,
16
>
()
{
{
return
DppInstr
::
dpp8_f16_8x16x2
;
return
DppInstr
::
dpp8_f16_8x16x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
16
,
16
>
()
{
{
return
DppInstr
::
dpp8_f16_16x16x2
;
return
DppInstr
::
dpp8_f16_16x16x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
32
,
8
>
()
constexpr
auto
GetDpp
<
half_t
,
32
,
8
>
()
{
{
return
DppInstr
::
dpp8_f16_32x8x2
;
return
DppInstr
::
dpp8_f16_32x8x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
1
,
32
>
()
{
{
return
DppInstr
::
dpp8_f16_1x32x2
;
return
DppInstr
::
dpp8_f16_1x32x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
2
,
32
>
()
{
{
return
DppInstr
::
dpp8_f16_2x32x2
;
return
DppInstr
::
dpp8_f16_2x32x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
2
,
16
>
()
{
{
return
DppInstr
::
dpp8_f16_2x16x2
;
return
DppInstr
::
dpp8_f16_2x16x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
constexpr
auto
GetDpp
<
half_t
,
4
,
16
>
()
{
{
return
DppInstr
::
dpp8_f16_4x16x2
;
return
DppInstr
::
dpp8_f16_4x16x2
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
constexpr
auto
GetDpp
<
half_t
,
4
,
32
>
()
{
{
return
DppInstr
::
dpp8_f16_4x32x2
;
return
DppInstr
::
dpp8_f16_4x32x2
;
}
}
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
df35f46d
...
@@ -415,7 +415,7 @@ struct WmmaSelector
...
@@ -415,7 +415,7 @@ struct WmmaSelector
static
constexpr
auto
GetWmma
();
static
constexpr
auto
GetWmma
();
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
{
{
#ifdef __gfx12__
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
;
return
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
;
...
@@ -425,7 +425,7 @@ struct WmmaSelector
...
@@ -425,7 +425,7 @@ struct WmmaSelector
}
}
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
float
,
16
,
16
>
()
{
{
#ifdef __gfx12__
#ifdef __gfx12__
return
WmmaInstr
::
wmma_f32_16x16x16_bf16_gfx12
;
return
WmmaInstr
::
wmma_f32_16x16x16_bf16_gfx12
;
...
@@ -435,19 +435,19 @@ struct WmmaSelector
...
@@ -435,19 +435,19 @@ struct WmmaSelector
}
}
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
half_t
,
16
,
16
>
()
constexpr
auto
GetWmma
<
half_t
,
half_t
,
half_t
,
16
,
16
>
()
{
{
return
WmmaInstr
::
wmma_f16_16x16x16_f16
;
return
WmmaInstr
::
wmma_f16_16x16x16_f16
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
bhalf_t
,
16
,
16
>
()
constexpr
auto
GetWmma
<
bhalf_t
,
bhalf_t
,
bhalf_t
,
16
,
16
>
()
{
{
return
WmmaInstr
::
wmma_bf16_16x16x16_bf16
;
return
WmmaInstr
::
wmma_bf16_16x16x16_bf16
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
constexpr
auto
GetWmma
<
int8_t
,
int8_t
,
int
,
16
,
16
>
()
{
{
#ifdef __gfx12__
#ifdef __gfx12__
return
WmmaInstr
::
wmma_i32_16x16x16_iu8_gfx12
;
return
WmmaInstr
::
wmma_i32_16x16x16_iu8_gfx12
;
...
@@ -458,7 +458,7 @@ struct WmmaSelector
...
@@ -458,7 +458,7 @@ struct WmmaSelector
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
template
<
>
static
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
constexpr
auto
GetWmma
<
int4_t
,
int4_t
,
int
,
16
,
16
>
()
{
{
return
WmmaInstr
::
wmma_i32_16x16x16_iu4
;
return
WmmaInstr
::
wmma_i32_16x16x16_iu4
;
}
}
...
...
include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
View file @
df35f46d
...
@@ -651,97 +651,97 @@ struct MfmaSelector
...
@@ -651,97 +651,97 @@ struct MfmaSelector
static
constexpr
auto
GetMfma
();
static
constexpr
auto
GetMfma
();
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
double
,
16
,
16
>
()
constexpr
auto
GetMfma
<
double
,
16
,
16
>
()
{
{
return
MfmaInstr
::
mfma_f64_16x16x4f64
;
return
MfmaInstr
::
mfma_f64_16x16x4f64
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
float
,
64
,
64
>
()
constexpr
auto
GetMfma
<
float
,
64
,
64
>
()
{
{
return
MfmaInstr
::
mfma_f32_32x32x1xf32
;
return
MfmaInstr
::
mfma_f32_32x32x1xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
float
,
32
,
64
>
()
constexpr
auto
GetMfma
<
float
,
32
,
64
>
()
{
{
return
MfmaInstr
::
mfma_f32_32x32x1xf32
;
return
MfmaInstr
::
mfma_f32_32x32x1xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
float
,
16
,
64
>
()
constexpr
auto
GetMfma
<
float
,
16
,
64
>
()
{
{
return
MfmaInstr
::
mfma_f32_16x16x1xf32
;
return
MfmaInstr
::
mfma_f32_16x16x1xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
float
,
8
,
64
>
()
constexpr
auto
GetMfma
<
float
,
8
,
64
>
()
{
{
return
MfmaInstr
::
mfma_f32_4x4x1xf32
;
return
MfmaInstr
::
mfma_f32_4x4x1xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
float
,
4
,
64
>
()
constexpr
auto
GetMfma
<
float
,
4
,
64
>
()
{
{
return
MfmaInstr
::
mfma_f32_4x4x1xf32
;
return
MfmaInstr
::
mfma_f32_4x4x1xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
float
,
32
,
32
>
()
constexpr
auto
GetMfma
<
float
,
32
,
32
>
()
{
{
return
MfmaInstr
::
mfma_f32_32x32x2xf32
;
return
MfmaInstr
::
mfma_f32_32x32x2xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
float
,
16
,
16
>
()
constexpr
auto
GetMfma
<
float
,
16
,
16
>
()
{
{
return
MfmaInstr
::
mfma_f32_16x16x4xf32
;
return
MfmaInstr
::
mfma_f32_16x16x4xf32
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
64
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
64
,
64
>
()
{
{
return
MfmaInstr
::
mfma_f32_32x32x4f16
;
return
MfmaInstr
::
mfma_f32_32x32x4f16
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
32
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
32
,
64
>
()
{
{
return
MfmaInstr
::
mfma_f32_32x32x4f16
;
return
MfmaInstr
::
mfma_f32_32x32x4f16
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
half_t
,
32
,
32
>
()
{
{
return
MfmaInstr
::
mfma_f32_32x32x8f16
;
return
MfmaInstr
::
mfma_f32_32x32x8f16
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
half_t
,
16
,
16
>
()
{
{
return
MfmaInstr
::
mfma_f32_16x16x16f16
;
return
MfmaInstr
::
mfma_f32_16x16x16f16
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
16
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
16
,
64
>
()
{
{
return
MfmaInstr
::
mfma_f32_16x16x4f16
;
return
MfmaInstr
::
mfma_f32_16x16x4f16
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
8
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
8
,
64
>
()
{
{
return
MfmaInstr
::
mfma_f32_4x4x4f16
;
return
MfmaInstr
::
mfma_f32_4x4x4f16
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
half_t
,
4
,
64
>
()
constexpr
auto
GetMfma
<
half_t
,
4
,
64
>
()
{
{
return
MfmaInstr
::
mfma_f32_4x4x4f16
;
return
MfmaInstr
::
mfma_f32_4x4x4f16
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
bhalf_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
bhalf_t
,
32
,
32
>
()
{
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return
MfmaInstr
::
mfma_f32_32x32x8bf16_1k
;
return
MfmaInstr
::
mfma_f32_32x32x8bf16_1k
;
...
@@ -751,7 +751,7 @@ struct MfmaSelector
...
@@ -751,7 +751,7 @@ struct MfmaSelector
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
bhalf_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
bhalf_t
,
16
,
16
>
()
{
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return
MfmaInstr
::
mfma_f32_16x16x16bf16_1k
;
return
MfmaInstr
::
mfma_f32_16x16x16bf16_1k
;
...
@@ -762,72 +762,72 @@ struct MfmaSelector
...
@@ -762,72 +762,72 @@ struct MfmaSelector
#if defined(CK_USE_AMD_MFMA_GFX940)
#if defined(CK_USE_AMD_MFMA_GFX940)
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
{
return
MfmaInstr
::
mfma_i32_32x32x16i8
;
return
MfmaInstr
::
mfma_i32_32x32x16i8
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
{
{
return
MfmaInstr
::
mfma_i32_16x16x32i8
;
return
MfmaInstr
::
mfma_i32_16x16x32i8
;
}
}
#else
#else
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
int8_t
,
32
,
32
>
()
{
{
return
MfmaInstr
::
mfma_i32_32x32x8i8
;
return
MfmaInstr
::
mfma_i32_32x32x8i8
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
int8_t
,
16
,
16
>
()
{
{
return
MfmaInstr
::
mfma_i32_16x16x16i8
;
return
MfmaInstr
::
mfma_i32_16x16x16i8
;
}
}
#endif
#endif
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
>
()
{
{
return
MfmaInstr
::
mfma_f32_32x32x16f8f8
;
return
MfmaInstr
::
mfma_f32_32x32x16f8f8
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
>
()
{
{
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
return
MfmaInstr
::
mfma_f32_16x16x32f8f8
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
>
()
{
{
return
MfmaInstr
::
mfma_f32_32x32x16bf8bf8
;
return
MfmaInstr
::
mfma_f32_32x32x16bf8bf8
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
>
()
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
>
()
{
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8bf8
;
return
MfmaInstr
::
mfma_f32_16x16x32bf8bf8
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
constexpr
auto
GetMfma
<
f8_t
,
32
,
32
,
bf8_t
>
()
{
{
return
MfmaInstr
::
mfma_f32_32x32x16f8bf8
;
return
MfmaInstr
::
mfma_f32_32x32x16f8bf8
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
,
bf8_t
>
()
constexpr
auto
GetMfma
<
f8_t
,
16
,
16
,
bf8_t
>
()
{
{
return
MfmaInstr
::
mfma_f32_16x16x32f8bf8
;
return
MfmaInstr
::
mfma_f32_16x16x32f8bf8
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
constexpr
auto
GetMfma
<
bf8_t
,
32
,
32
,
f8_t
>
()
{
{
return
MfmaInstr
::
mfma_f32_32x32x16bf8f8
;
return
MfmaInstr
::
mfma_f32_32x32x16bf8f8
;
}
}
template
<
>
template
<
>
static
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
,
f8_t
>
()
constexpr
auto
GetMfma
<
bf8_t
,
16
,
16
,
f8_t
>
()
{
{
return
MfmaInstr
::
mfma_f32_16x16x32bf8f8
;
return
MfmaInstr
::
mfma_f32_16x16x32bf8f8
;
}
}
...
...
include/ck_tile/core/container/array.hpp
View file @
df35f46d
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <initializer_list>
#include <initializer_list>
#include <vector>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integer.hpp"
...
@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
...
@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
return
!
(
a
==
b
);
return
!
(
a
==
b
);
}
}
template
<
typename
T
,
index_t
N
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
std
::
vector
<
X
>&
x
)
{
array
<
T
,
N
>
arr
;
static_for
<
0
,
N
,
1
>
{}([
&
x
,
&
arr
](
auto
i
)
{
arr
(
i
)
=
x
[
i
];
});
return
arr
;
}
template
<
typename
T
,
index_t
N
,
typename
X
>
template
<
typename
T
,
index_t
N
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
X
&
x
)
CK_TILE_HOST_DEVICE
constexpr
auto
to_array
(
const
X
&
x
)
{
{
...
...
include/ck_tile/host.hpp
View file @
df35f46d
...
@@ -5,6 +5,8 @@
...
@@ -5,6 +5,8 @@
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/arg_parser.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
#include "ck_tile/host/convolution_parameter.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/device_memory.hpp"
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/hip_check_error.hpp"
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment