gaoqiong / composable_kernel_ROCM · Commits

Commit e4112de7
Authored May 22, 2024 by Jun Liu

    Merge branch 'develop' into amd-develop

Parents: a6ef5c39, fd72380a

Changes: 39 files changed. Showing 20 changed files with 2813 additions and 736 deletions (+2813 -736).
Changed files:

  CMakeLists.txt (+10 -6)
  client_example/07_grouped_convnd_fwd/CMakeLists.txt (+19 -1)
  client_example/07_grouped_convnd_fwd/common.hpp (+304 -0)
  client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp (+11 -201)
  client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp (+11 -169)
  client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp (+0 -0, moved)
  client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp (+0 -0, moved)
  client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp (+0 -0, moved)
  client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp (+0 -0, moved)
  client_example/16_convnd_fwd/CMakeLists.txt (+0 -16)
  example/CMakeLists.txt (+31 -4)
  include/ck/host_utility/flush_cache.hpp (+14 -9)
  include/ck/tensor_description/multi_index_transform.hpp (+11 -4)
  include/ck/tensor_description/multi_index_transform_helper.hpp (+8 -2)
  include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp (+1 -2)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp (+1010 -304)
  include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp (+1 -2)
  include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp (+2 -3)
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp (+1369 -0)
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp (+11 -13)
CMakeLists.txt

@@ -23,7 +23,7 @@ endif()
 set(version 1.1.0)
 # Check support for CUDA/HIP in Cmake
-project(composable_kernel VERSION ${version} LANGUAGES CXX)
+project(composable_kernel VERSION ${version} LANGUAGES CXX HIP)
 include(CTest)
 find_package(Python3 3.6 COMPONENTS Interpreter REQUIRED)
@@ -112,7 +112,7 @@ message("checking which targets are supported")
 #Setting GPU_TARGETS on command line will override this list
 if(NOT PROFILER_ONLY)
   rocm_check_target_ids(DEFAULT_GPU_TARGETS
-      TARGETS "gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
+      TARGETS "gfx900;gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102")
 else()
   add_definitions(-DPROFILER_ONLY)
   set(GPU_TARGETS "" CACHE STRING "" FORCE)
@@ -135,12 +135,10 @@ endif()
 message("Supported GPU_TARGETS=${DEFAULT_GPU_TARGETS}")
-set(AMDGPU_TARGETS "${DEFAULT_GPU_TARGETS}" CACHE STRING " " FORCE)
 if(GPU_TARGETS)
   message("Building CK for the following targets: ${GPU_TARGETS}")
 else()
-  message("Building CK for the following targets: ${AMDGPU_TARGETS}")
+  message("Building CK for the default targets: ${DEFAULT_GPU_TARGETS}")
 endif()
 if(GPU_TARGETS)
@@ -225,7 +223,13 @@ link_libraries(Threads::Threads)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
 set(CMAKE_CXX_EXTENSIONS OFF)
+message("CMAKE_CXX_COMPILER_ID: ${CMAKE_CXX_COMPILER_ID}")
+message("CMAKE_CXX_COMPILER: ${CMAKE_CXX_COMPILER}")
+## HIP
+set(CMAKE_HIP_PLATFORM amd)
+set(CMAKE_HIP_COMPILER ${CMAKE_CXX_COMPILER})
+set(CMAKE_HIP_EXTENSIONS ON)
+message("CMAKE_HIP_COMPILER: ${CMAKE_HIP_COMPILER}")
 ## OpenMP
 if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
client_example/07_grouped_convnd_fwd/CMakeLists.txt

@@ -4,4 +4,22 @@ if(GPU_TARGETS MATCHES "gfx9")
 add_executable(client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp)
 target_link_libraries(client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations)
-endif()
\ No newline at end of file
+if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
+    add_executable(client_grouped_conv3d_fwd_fp8 grouped_conv3d_fwd_fp8.cpp)
+    target_link_libraries(client_grouped_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations)
+endif()
+if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
+    add_executable(client_grouped_conv3d_fwd_bf8 grouped_conv3d_fwd_bf8.cpp)
+    target_link_libraries(client_grouped_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations)
+endif()
+if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
+    add_executable(client_grouped_conv3d_fwd_fp8_bf8 grouped_conv3d_fwd_fp8_bf8.cpp)
+    target_link_libraries(client_grouped_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
+    add_executable(client_grouped_conv3d_fwd_bf8_fp8 grouped_conv3d_fwd_bf8_fp8.cpp)
+    target_link_libraries(client_grouped_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
+endif()
+endif()
client_example/07_grouped_convnd_fwd/common.hpp (new file, mode 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
#include <vector>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

template <ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetFlops(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths,
         const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths)
{
    // 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
    ck::index_t G = weights_lengths[0];
    ck::index_t N = output_lengths[1];
    ck::index_t K = weights_lengths[1];
    ck::index_t C = weights_lengths[2];

    return static_cast<std::size_t>(2) * G * N * K * C *
           std::accumulate(std::next(std::begin(output_lengths), NumNonSpatialDim),
                           std::end(output_lengths),
                           static_cast<std::size_t>(1),
                           std::multiplies<>()) *
           std::accumulate(std::next(std::begin(weights_lengths), NumNonSpatialDim),
                           std::end(weights_lengths),
                           static_cast<std::size_t>(1),
                           std::multiplies<>());
}

template <typename InDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetInputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& input_lengths)
{
    // sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
    return sizeof(InDataType) * std::accumulate(std::begin(input_lengths),
                                                std::end(input_lengths),
                                                static_cast<std::size_t>(1),
                                                std::multiplies<>());
}

template <typename WeiDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetWeightByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& weights_lengths)
{
    // sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
    return sizeof(WeiDataType) * std::accumulate(std::begin(weights_lengths),
                                                 std::end(weights_lengths),
                                                 static_cast<std::size_t>(1),
                                                 std::multiplies<>());
}

template <typename OutDataType, ck::index_t NumDimSpatial, ck::index_t NumNonSpatialDim = 3>
std::size_t
GetOutputByte(const std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>& output_lengths)
{
    // sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
    return sizeof(OutDataType) * std::accumulate(std::begin(output_lengths),
                                                 std::end(output_lengths),
                                                 static_cast<std::size_t>(1),
                                                 std::multiplies<std::size_t>());
}

template <ck::index_t NumDimSpatial,
          typename InDataType,
          typename WeiDataType,
          typename OutDataType,
          typename InLayout,
          typename WeiLayout,
          typename OutLayout,
          ck::index_t NumNonSpatialDim = 3,
          typename AComputeType        = InDataType,
          typename BComputeType        = AComputeType>
bool run_grouped_conv_fwd(std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_lengths,
                          std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_lengths,
                          std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_lengths)
{
    std::size_t in_mem_size  = GetInputByte<InDataType, NumDimSpatial>(in_lengths);
    std::size_t wei_mem_size = GetWeightByte<WeiDataType, NumDimSpatial>(wei_lengths);
    std::size_t out_mem_size = GetOutputByte<OutDataType, NumDimSpatial>(out_lengths);

    SimpleDeviceMem in(in_mem_size);
    SimpleDeviceMem wei(wei_mem_size);
    SimpleDeviceMem out(out_mem_size);

    std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> in_strides;
    std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> wei_strides;
    std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim> out_strides;
    in_strides.fill(0);
    wei_strides.fill(0);
    out_strides.fill(0);
    in_strides.back()  = 1;
    wei_strides.back() = 1;
    out_strides.back() = 1;

    std::partial_sum(rbegin(in_lengths), std::prev(rend(in_lengths)),
                     std::next(rbegin(in_strides)), std::multiplies<>{});
    std::partial_sum(rbegin(wei_lengths), std::prev(rend(wei_lengths)),
                     std::next(rbegin(wei_strides)), std::multiplies<>{});
    std::partial_sum(rbegin(out_lengths), std::prev(rend(out_lengths)),
                     std::next(rbegin(out_strides)), std::multiplies<>{});

    // transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW
    std::rotate(std::next(rbegin(in_lengths)), std::next(rbegin(in_lengths), 2), rend(in_lengths));
    std::rotate(rbegin(in_lengths), std::next(rbegin(in_lengths)),
                std::next(rbegin(in_lengths), NumDimSpatial + 1));
    std::rotate(std::next(rbegin(in_strides)), std::next(rbegin(in_strides), 2), rend(in_strides));
    std::rotate(rbegin(in_strides), std::next(rbegin(in_strides)),
                std::next(rbegin(in_strides), NumDimSpatial + 1));
    std::rotate(rbegin(wei_lengths), std::next(rbegin(wei_lengths)),
                std::next(rbegin(wei_lengths), NumDimSpatial + 1));
    std::rotate(rbegin(wei_strides), std::next(rbegin(wei_strides)),
                std::next(rbegin(wei_strides), NumDimSpatial + 1));
    std::rotate(std::next(rbegin(out_lengths)), std::next(rbegin(out_lengths), 2), rend(out_lengths));
    std::rotate(rbegin(out_lengths), std::next(rbegin(out_lengths)),
                std::next(rbegin(out_lengths), NumDimSpatial + 1));
    std::rotate(std::next(rbegin(out_strides)), std::next(rbegin(out_strides), 2), rend(out_strides));
    std::rotate(rbegin(out_strides), std::next(rbegin(out_strides)),
                std::next(rbegin(out_strides), NumDimSpatial + 1));

    std::array<ck::index_t, NumDimSpatial> conv_filter_strides;
    std::array<ck::index_t, NumDimSpatial> conv_filter_dilations;
    std::array<ck::index_t, NumDimSpatial> input_left_pads;
    std::array<ck::index_t, NumDimSpatial> input_right_pads;
    conv_filter_strides.fill(1);
    conv_filter_dilations.fill(1);
    input_left_pads.fill(1);
    input_right_pads.fill(1);

    std::size_t flop      = GetFlops<NumDimSpatial>(out_lengths, wei_lengths);
    std::size_t num_bytes = in_mem_size + wei_mem_size + out_mem_size;

    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<NumDimSpatial,
                                                                                   InLayout,
                                                                                   WeiLayout,
                                                                                   ck::Tuple<>,
                                                                                   OutLayout,
                                                                                   InDataType,
                                                                                   WeiDataType,
                                                                                   ck::Tuple<>,
                                                                                   OutDataType,
                                                                                   PassThrough,
                                                                                   PassThrough,
                                                                                   PassThrough,
                                                                                   AComputeType,
                                                                                   BComputeType>;

    // get device op instances
    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    std::string best_op_name;
    int best_op_id        = -1;
    float best_avg_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;
    float best_tflops     = 0;

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr      = op_ptrs[i];
        auto argument_ptr = op_ptr->MakeArgumentPointer(
            in.GetDeviceBuffer(),
            wei.GetDeviceBuffer(),
            std::array<const void*, 0>{},
            out.GetDeviceBuffer(),
            in_lengths,
            in_strides,
            wei_lengths,
            wei_strides,
            std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
            std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
            out_lengths,
            out_strides,
            conv_filter_strides,
            conv_filter_dilations,
            input_left_pads,
            input_right_pads,
            PassThrough{},
            PassThrough{},
            PassThrough{});

        auto invoker_ptr    = op_ptr->MakeInvokerPointer();
        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
            float gb_per_sec = num_bytes / 1.E6 / avg_time;

            std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
                      << gb_per_sec << " GB/s, " << op_name << std::endl;

            if(tflops > best_tflops)
            {
                best_op_id      = i;
                best_op_name    = op_name;
                best_avg_time   = avg_time;
                best_gb_per_sec = gb_per_sec;
                best_tflops     = tflops;
            }
        }
        else
        {
            std::cerr << op_name << " does not support this problem" << std::endl;
        }
    }

    if(best_op_id < 0)
    {
        std::cerr << "no suitable instance" << std::endl;
        return false;
    }

    std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
              << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;

    // run the best intance
    {
        auto& op_ptr = op_ptrs[best_op_id];
        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;

        auto argument_ptr = op_ptr->MakeArgumentPointer(
            in.GetDeviceBuffer(),
            wei.GetDeviceBuffer(),
            std::array<const void*, 0>{},
            out.GetDeviceBuffer(),
            in_lengths,
            in_strides,
            wei_lengths,
            wei_strides,
            std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
            std::array<std::array<ck::index_t, NumDimSpatial + NumNonSpatialDim>, 0>{{}},
            out_lengths,
            out_strides,
            conv_filter_strides,
            conv_filter_dilations,
            input_left_pads,
            input_right_pads,
            PassThrough{},
            PassThrough{},
            PassThrough{});

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }
    return true;
}
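The helper above builds packed strides with std::partial_sum over the reversed lengths and then permutes lengths and strides from NDHWGC/KZYXGC/NDHWGK order into the GNCDHW-style order the device op expects. The following is a standalone, host-only sketch of that bookkeeping (not CK code); the extents, the dimension labels, and the permute helper are made up for illustration only.

// Sketch of the stride/permutation bookkeeping used by run_grouped_conv_fwd, on a
// hypothetical 3-D NDHWGC tensor. Everything here is illustrative, not a CK API.
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>

int main()
{
    std::array<char, 6> dims{'N', 'D', 'H', 'W', 'G', 'C'};  // memory order of the tensor
    std::array<std::size_t, 6> lengths{2, 3, 4, 5, 6, 7};    // made-up extents
    constexpr int num_spatial = 3;                           // D, H, W

    // Packed strides: the innermost dimension (C) gets stride 1; every other stride is the
    // running product of the lengths to its right, exactly like the std::partial_sum above.
    std::array<std::size_t, 6> strides{};
    strides.back() = 1;
    std::partial_sum(rbegin(lengths), std::prev(rend(lengths)),
                     std::next(rbegin(strides)), std::multiplies<>{});

    // The same two rotations as in run_grouped_conv_fwd, applied to the labels as well so
    // the permutation is visible: NDHWGC -> GNDHWC -> GNCDHW.
    auto permute = [&](auto& a) {
        std::rotate(std::next(rbegin(a)), std::next(rbegin(a), 2), rend(a));
        std::rotate(rbegin(a), std::next(rbegin(a)), std::next(rbegin(a), num_spatial + 1));
    };
    permute(dims);
    permute(lengths);
    permute(strides);

    for(std::size_t i = 0; i < dims.size(); ++i)
        std::cout << dims[i] << ": length " << lengths[i] << ", stride " << strides[i] << '\n';
    // Prints the dimensions in G N C D H W order, each still carrying the stride it had in
    // the original NDHWGC memory layout.
}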
client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp

 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
-#include <cstdlib>
-#include <iomanip>
-#include <iostream>
-#include <iterator>
-#include <numeric>
-#include <vector>
+#include "common.hpp"
 
 #include "ck/ck.hpp"
-#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
 using InDataType  = ck::half_t;
 using WeiDataType = ck::half_t;
@@ -31,199 +24,16 @@ static constexpr ck::index_t X = 3;
 static constexpr ck::index_t Wi = 28;
 static constexpr ck::index_t Wo = 28;
 
-struct SimpleDeviceMem
-{
-    SimpleDeviceMem() = delete;
-
-    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
-    {
-        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
-    }
-
-    void* GetDeviceBuffer() { return p_mem_; }
-
-    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
-
-    void* p_mem_;
-};
-
 int main()
 {
-    std::array<ck::index_t, NumDimSpatial + 3> in_lengths{G, N, Wi, C};
-    std::array<ck::index_t, NumDimSpatial + 3> in_strides{0, 0, 0, 1};
-    std::array<ck::index_t, NumDimSpatial + 3> wei_lengths{G, K, X, C};
-    std::array<ck::index_t, NumDimSpatial + 3> wei_strides{0, 0, 0, 1};
-    std::array<ck::index_t, NumDimSpatial + 3> out_lengths{G, N, Wo, K};
-    std::array<ck::index_t, NumDimSpatial + 3> out_strides{0, 0, 0, 1};
-
-    std::partial_sum(rbegin(in_lengths), std::prev(rend(in_lengths)),
-                     std::next(rbegin(in_strides)), std::multiplies<>{});
-    std::partial_sum(rbegin(wei_lengths), std::prev(rend(wei_lengths)),
-                     std::next(rbegin(wei_strides)), std::multiplies<>{});
-    std::partial_sum(rbegin(out_lengths), std::prev(rend(out_lengths)),
-                     std::next(rbegin(out_strides)), std::multiplies<>{});
-
-    // transpose GNWC/GKXC/GNWK to GNCW/GKCX/GNCW
-    std::rotate(rbegin(in_lengths), std::next(rbegin(in_lengths)),
-                std::next(rbegin(in_lengths), NumDimSpatial + 1));
-    std::rotate(rbegin(in_strides), std::next(rbegin(in_strides)),
-                std::next(rbegin(in_strides), NumDimSpatial + 1));
-    std::rotate(rbegin(wei_lengths), std::next(rbegin(wei_lengths)),
-                std::next(rbegin(wei_lengths), NumDimSpatial + 1));
-    std::rotate(rbegin(wei_strides), std::next(rbegin(wei_strides)),
-                std::next(rbegin(wei_strides), NumDimSpatial + 1));
-    std::rotate(rbegin(out_lengths), std::next(rbegin(out_lengths)),
-                std::next(rbegin(out_lengths), NumDimSpatial + 1));
-    std::rotate(rbegin(out_strides), std::next(rbegin(out_strides)),
-                std::next(rbegin(out_strides), NumDimSpatial + 1));
-
-    std::array<ck::index_t, NumDimSpatial> filter_strides{1};
-    std::array<ck::index_t, NumDimSpatial> filter_dilations{1};
-    std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
-    std::array<ck::index_t, NumDimSpatial> input_right_pads{1};
-
-    SimpleDeviceMem in(sizeof(InDataType) * G * N * Wi * C);
-    SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * X * C);
-    SimpleDeviceMem out(sizeof(OutDataType) * G * N * Wo * K);
-
-    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
-        NumDimSpatial, InLayout, WeiLayout, ck::Tuple<>, OutLayout,
-        InDataType, WeiDataType, ck::Tuple<>, OutDataType,
-        PassThrough, PassThrough, PassThrough>;
-    // get device op instances
-    const auto op_ptrs = ck::tensor_operation::device::instance::
-        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
-
-    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
-
-    std::string best_op_name;
-    int best_op_id        = -1;
-    float best_avg_time   = std::numeric_limits<float>::max();
-    float best_gb_per_sec = 0;
-    float best_tflops     = 0;
-
-    // profile device operation instances
-    std::cout << "Run all instances and do timing" << std::endl;
-
-    for(int i = 0; i < op_ptrs.size(); ++i)
-    {
-        auto& op_ptr      = op_ptrs[i];
-        auto argument_ptr = op_ptr->MakeArgumentPointer(
-            in.GetDeviceBuffer(), wei.GetDeviceBuffer(), {}, out.GetDeviceBuffer(),
-            in_lengths, in_strides, wei_lengths, wei_strides, {}, {},
-            out_lengths, out_strides, filter_strides, filter_dilations,
-            input_left_pads, input_right_pads, PassThrough{}, PassThrough{}, PassThrough{});
-
-        auto invoker_ptr    = op_ptr->MakeInvokerPointer();
-        std::string op_name = op_ptr->GetTypeString();
-
-        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
-        {
-            float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
-
-            std::size_t flop      = std::size_t(2) * G * N * K * C * Wo * X;
-            std::size_t num_bytes = sizeof(InDataType) * G * N * Wi * C +
-                                    sizeof(WeiDataType) * G * K * X * C +
-                                    sizeof(OutDataType) * G * N * Wo * K;
-
-            float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
-            float gb_per_sec = num_bytes / 1.E6 / avg_time;
-
-            std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops
-                      << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl;
-
-            if(tflops > best_tflops)
-            {
-                best_op_id      = i;
-                best_op_name    = op_name;
-                best_avg_time   = avg_time;
-                best_gb_per_sec = gb_per_sec;
-                best_tflops     = tflops;
-            }
-        }
-        else
-        {
-            std::cerr << op_name << " does not support this problem" << std::endl;
-        }
-    }
-
-    if(best_op_id < 0)
-    {
-        std::cerr << "no suitable instance" << std::endl;
-        return EXIT_FAILURE;
-    }
-
-    std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
-              << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
-
-    // run the best intance
-    {
-        auto& op_ptr = op_ptrs[best_op_id];
-        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
-                  << std::endl;
-        auto argument_ptr = op_ptr->MakeArgumentPointer(
-            in.GetDeviceBuffer(), wei.GetDeviceBuffer(), {}, out.GetDeviceBuffer(),
-            in_lengths, in_strides, wei_lengths, wei_strides, {}, {},
-            out_lengths, out_strides, filter_strides, filter_dilations,
-            input_left_pads, input_right_pads, PassThrough{}, PassThrough{}, PassThrough{});
-        auto invoker_ptr = op_ptr->MakeInvokerPointer();
-
-        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
-        {
-            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
-        }
-
-        std::cout << "Done" << std::endl;
-    }
+    return run_grouped_conv_fwd<NumDimSpatial,
+                                InDataType,
+                                WeiDataType,
+                                OutDataType,
+                                InLayout,
+                                WeiLayout,
+                                OutLayout,
+                                3>({N, Wi, G, C}, {G, K, X, C}, {N, Wo, G, K})
+               ? EXIT_SUCCESS
+               : EXIT_FAILURE;
 }
client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp

 // SPDX-License-Identifier: MIT
 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
-#include <cstdlib>
-#include <iomanip>
-#include <iostream>
-#include <iterator>
-#include <numeric>
-#include <vector>
+#include "common.hpp"
 
 #include "ck/ck.hpp"
-#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
 
 using InDataType  = ck::half_t;
 using WeiDataType = ck::half_t;
@@ -34,167 +27,16 @@ static constexpr ck::index_t Wi = 28; // input W
 static constexpr ck::index_t Ho = 28; // output H
 static constexpr ck::index_t Wo = 28; // output W
 
-struct SimpleDeviceMem
-{
-    SimpleDeviceMem() = delete;
-
-    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
-    {
-        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
-    }
-
-    void* GetDeviceBuffer() { return p_mem_; }
-
-    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }
-
-    void* p_mem_;
-};
-
 int main()
 {
-    // We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space
-    // However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW
-    // Hence, we need to adjust the order of stride
-    std::array<ck::index_t, 5> in_lengths{G, N, C, Hi, Wi};
-    std::array<ck::index_t, 5> in_strides{C, Hi * Wi * G * C, 1, Wi * G * C, G * C};
-    std::array<ck::index_t, 5> wei_lengths{G, K, C, Y, X};
-    std::array<ck::index_t, 5> wei_strides{K * Y * X * C, Y * X * C, 1, X * C, C};
-    std::array<ck::index_t, 5> out_lengths{G, N, K, Ho, Wo};
-    std::array<ck::index_t, 5> out_strides{C, Ho * Wo * G * C, 1, Wo * G * C, G * C};
-
-    std::array<ck::index_t, NumDimSpatial> filter_strides{1, 1};
-    std::array<ck::index_t, NumDimSpatial> filter_dilations{1, 1};
-    std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
-    std::array<ck::index_t, NumDimSpatial> input_right_pads{1, 1};
-
-    SimpleDeviceMem in(sizeof(InDataType) * N * Hi * Wi * G * C);
-    SimpleDeviceMem wei(sizeof(WeiDataType) * G * K * Y * X * C);
-    SimpleDeviceMem out(sizeof(OutDataType) * N * Ho * Wo * G * K);
-
-    using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
-        NumDimSpatial, InLayout, WeiLayout, ck::Tuple<>, OutLayout,
-        InDataType, WeiDataType, ck::Tuple<>, OutDataType,
-        PassThrough, PassThrough, PassThrough>;
-    // get device op instances
-    const auto op_ptrs = ck::tensor_operation::device::instance::
-        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
-
-    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;
-
-    std::string best_op_name;
-    int best_op_id        = -1;
-    float best_avg_time   = std::numeric_limits<float>::max();
-    float best_gb_per_sec = 0;
-    float best_tflops     = 0;
-
-    // profile device operation instances
-    std::cout << "Run all instances and do timing" << std::endl;
-
-    for(int i = 0; i < op_ptrs.size(); ++i)
-    {
-        auto& op_ptr      = op_ptrs[i];
-        auto argument_ptr = op_ptr->MakeArgumentPointer(
-            in.GetDeviceBuffer(), wei.GetDeviceBuffer(), {}, out.GetDeviceBuffer(),
-            in_lengths, in_strides, wei_lengths, wei_strides, {}, {},
-            out_lengths, out_strides, filter_strides, filter_dilations,
-            input_left_pads, input_right_pads, PassThrough{}, PassThrough{}, PassThrough{});
-
-        auto invoker_ptr    = op_ptr->MakeInvokerPointer();
-        std::string op_name = op_ptr->GetTypeString();
-
-        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
-        {
-            float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
-
-            std::size_t flop      = std::size_t(2) * G * N * K * C * Ho * Wo * Y * X;
-            std::size_t num_bytes = sizeof(InDataType) * N * Hi * Wi * G * C +
-                                    sizeof(WeiDataType) * G * K * Y * X * C +
-                                    sizeof(OutDataType) * N * Ho * Wo * G * K;
-
-            float tflops     = static_cast<float>(flop) / 1.E9 / avg_time;
-            float gb_per_sec = num_bytes / 1.E6 / avg_time;
-
-            std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops
-                      << " TFlops, " << gb_per_sec << " GB/s, " << op_name << std::endl;
-
-            if(tflops > best_tflops)
-            {
-                best_op_id      = i;
-                best_op_name    = op_name;
-                best_avg_time   = avg_time;
-                best_gb_per_sec = gb_per_sec;
-                best_tflops     = tflops;
-            }
-        }
-        else
-        {
-            std::cerr << op_name << " does not support this problem" << std::endl;
-        }
-    }
-
-    if(best_op_id < 0)
-    {
-        std::cerr << "no suitable instance" << std::endl;
-        return EXIT_FAILURE;
-    }
-
-    std::cout << "Best Perf: " << std::setw(10) << best_avg_time << " ms, " << best_tflops
-              << " TFlops, " << best_gb_per_sec << " GB/s, " << best_op_name << std::endl;
-
-    // run the best intance
-    {
-        auto& op_ptr = op_ptrs[best_op_id];
-        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
-                  << std::endl;
-        auto argument_ptr = op_ptr->MakeArgumentPointer(
-            in.GetDeviceBuffer(), wei.GetDeviceBuffer(), {}, out.GetDeviceBuffer(),
-            in_lengths, in_strides, wei_lengths, wei_strides, {}, {},
-            out_lengths, out_strides, filter_strides, filter_dilations,
-            input_left_pads, input_right_pads, PassThrough{}, PassThrough{}, PassThrough{});
-        auto invoker_ptr = op_ptr->MakeInvokerPointer();
-
-        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
-        {
-            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
-        }
-
-        std::cout << "Done" << std::endl;
-    }
+    return run_grouped_conv_fwd<NumDimSpatial,
+                                InDataType,
+                                WeiDataType,
+                                OutDataType,
+                                InLayout,
+                                WeiLayout,
+                                OutLayout,
+                                3>({N, Hi, Wi, G, C}, {G, K, Y, X, C}, {N, Ho, Wo, G, K})
+               ? EXIT_SUCCESS
+               : EXIT_FAILURE;
 }
Files moved (no content changes):

  client_example/16_convnd_fwd/conv3d_fwd_bf8.cpp     → client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp
  client_example/16_convnd_fwd/conv3d_fwd_bf8_fp8.cpp → client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp
  client_example/16_convnd_fwd/conv3d_fwd_fp8.cpp     → client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp
  client_example/16_convnd_fwd/conv3d_fwd_fp8_bf8.cpp → client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp
client_example/16_convnd_fwd/CMakeLists.txt

@@ -7,22 +7,6 @@ endif()
 if((DTYPES MATCHES "fp8") OR NOT DEFINED DTYPES)
   add_executable(client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp)
   target_link_libraries(client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations)
-  add_executable(client_conv3d_fwd_fp8 conv3d_fwd_fp8.cpp)
-  target_link_libraries(client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations)
-endif()
-if((DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
-  add_executable(client_conv3d_fwd_bf8 conv3d_fwd_bf8.cpp)
-  target_link_libraries(client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations)
-endif()
-if((DTYPES MATCHES "fp8" AND DTYPES MATCHES "bf8") OR NOT DEFINED DTYPES)
-  add_executable(client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp)
-  target_link_libraries(client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations)
-  add_executable(client_conv3d_fwd_bf8_fp8 conv3d_fwd_bf8_fp8.cpp)
-  target_link_libraries(client_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations)
 endif()
 if((DTYPES MATCHES "fp32") OR NOT DEFINED DTYPES)
example/CMakeLists.txt

@@ -44,6 +44,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
       endif()
     endforeach()
   endif()
+  if(INSTANCES_ONLY)
+    set(EX_TARGETS ${DEFAULT_GPU_TARGETS})
+  else()
+    set(EX_TARGETS ${GPU_TARGETS})
+  endif()
   #Do not build any DL examples if DL_KERNELS not set
   foreach(source IN LISTS FILE_NAME)
     if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
@@ -53,23 +60,30 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
   endforeach()
   #Do not build any XDL examples if gfx9 targets are not on the list
   foreach(source IN LISTS FILE_NAME)
-    if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
+    if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
       message("removing xdl example ${source}")
       list(REMOVE_ITEM FILE_NAME "${source}")
     endif()
   endforeach()
   #Do not build any WMMA examples if gfx11 targets are not on the list
   foreach(source IN LISTS FILE_NAME)
-    if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+    if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
       message("removing wmma example ${source}")
       list(REMOVE_ITEM FILE_NAME "${source}")
     endif()
   endforeach()
   #only continue if there are some source files left on the list
   if(FILE_NAME)
+    if(FILE_NAME MATCHES "_xdl")
+      list(REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
+    elseif(FILE_NAME MATCHES "_wmma")
+      list(REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+    endif()
+    set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
     add_executable(${EXAMPLE_NAME} ${FILE_NAME})
     target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
     add_test(NAME ${EXAMPLE_NAME} COMMAND $<TARGET_FILE:${EXAMPLE_NAME}> ${ARGN})
+    set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS})
     add_dependencies(examples ${EXAMPLE_NAME})
     add_dependencies(check ${EXAMPLE_NAME})
     rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
@@ -118,6 +132,12 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
       endif()
     endforeach()
   endif()
+  if(INSTANCES_ONLY)
+    set(EX_TARGETS ${DEFAULT_GPU_TARGETS})
+  else()
+    set(EX_TARGETS ${GPU_TARGETS})
+  endif()
   #Do not build any DL examples if DL_KERNELS not set
   foreach(source IN LISTS FILE_NAME)
     if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
@@ -127,23 +147,30 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
   endforeach()
   #Do not build any XDL examples if gfx9 targets are not on the list
   foreach(source IN LISTS FILE_NAME)
-    if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
+    if(NOT EX_TARGETS MATCHES "gfx9" AND source MATCHES "_xdl")
      message("removing xdl example ${source}")
      list(REMOVE_ITEM FILE_NAME "${source}")
    endif()
  endforeach()
  #Do not build any WMMA examples if gfx11 targets are not on the list
  foreach(source IN LISTS FILE_NAME)
-    if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
+    if(NOT EX_TARGETS MATCHES "gfx11" AND source MATCHES "_wmma")
      message("removing wmma example ${source}")
      list(REMOVE_ITEM FILE_NAME "${source}")
    endif()
  endforeach()
  #only continue if there are some source files left on the list
  if(FILE_NAME)
+    if(FILE_NAME MATCHES "_xdl")
+      list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
+    elseif(FILE_NAME MATCHES "_wmma")
+      list(REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+    endif()
+    set_source_files_properties(${FILE_NAME} PROPERTIES LANGUAGE HIP)
    add_executable(${EXAMPLE_NAME} ${FILE_NAME})
    target_link_libraries(${EXAMPLE_NAME} PRIVATE utility)
    add_dependencies(examples ${EXAMPLE_NAME})
+    set_property(TARGET ${EXAMPLE_NAME} PROPERTY HIP_ARCHITECTURES ${EX_TARGETS})
    rocm_install(TARGETS ${EXAMPLE_NAME} COMPONENT examples)
    set(result 0)
  endif()
include/ck/host_utility/flush_cache.hpp

@@ -104,14 +104,19 @@ inline void flush_icache()
     hip_check_error(hipGetLastError());
 }
 // if TimePrePress == false, return time does not include preprocess's time
-template <bool TimePreprocess, typename Args, typename F, typename PreProcessFunc>
+template <bool TimePreprocess,
+          typename GemmArgs,
+          typename... Args,
+          typename F,
+          typename PreProcessFunc>
 float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                                              PreProcessFunc preprocess,
                                              F kernel,
                                              dim3 grid_dim,
                                              dim3 block_dim,
                                              std::size_t lds_byte,
-                                             Args& args)
+                                             GemmArgs& gemm_args,
+                                             Args... args)
 {
 #if CK_TIME_KERNEL
 #define MEDIAN 1
@@ -133,7 +138,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
         // warm up
         for(int i = 0; i < stream_config.cold_niters_; ++i)
         {
-            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
+            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
             hip_check_error(hipGetLastError());
         }
@@ -172,7 +177,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
                 preprocess();
             }
             // run real kernel
-            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
+            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
             hip_check_error(hipGetLastError());
             // end real kernel
@@ -190,9 +195,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
             {
                 std::cout << "i: " << i << " cur_time: " << cur_time << std::endl;
-                printf("args.p_a_grid: %p, args.p_b_grid:%p\n",
-                       static_cast<const void*>(args.p_a_grid),
-                       static_cast<const void*>(args.p_b_grid));
+                printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid:%p\n",
+                       static_cast<const void*>(gemm_args.p_a_grid),
+                       static_cast<const void*>(gemm_args.p_b_grid));
             }
         }
@@ -216,13 +221,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
     else
     {
         preprocess();
-        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
         hip_check_error(hipGetLastError());
         return 0;
     }
 #else
-    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args);
+    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(gemm_args, args...);
     hip_check_error(hipGetLastError());
     return 0;
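The change above turns the timing wrapper's trailing argument into a primary GemmArgs reference plus a parameter pack that is forwarded to the kernel launch as (gemm_args, args...). Below is a host-only sketch of that forwarding shape, with no HIP and no real kernel launch; launch_and_time, DummyGemmArgs and the lambda "kernel" are illustrative stand-ins, not CK APIs.

// Host-only illustration of forwarding one primary argument plus a pack to a callable.
#include <cstdio>

template <typename F, typename GemmArgs, typename... Args>
float launch_and_time(F kernel, GemmArgs& gemm_args, Args... args)
{
    // Stand-in for kernel<<<grid, block, lds, stream>>>(gemm_args, args...).
    kernel(gemm_args, args...);
    return 0.0f; // a real wrapper would return the measured time in ms
}

struct DummyGemmArgs
{
    const void* p_a_grid;
    const void* p_b_grid;
};

int main()
{
    DummyGemmArgs gemm_args{nullptr, nullptr};

    // Extra per-launch arguments (here: two descriptors modeled as plain ints) ride along
    // in the pack, matching the new `GemmArgs& gemm_args, Args... args` parameter shape.
    auto kernel = [](const DummyGemmArgs& karg, int desc_a, int desc_b) {
        std::printf("gemm_args.p_a_grid: %p, gemm_args.p_b_grid: %p, descs: %d %d\n",
                    karg.p_a_grid, karg.p_b_grid, desc_a, desc_b);
    };

    launch_and_time(kernel, gemm_args, 1, 2);
}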
include/ck/tensor_description/multi_index_transform.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -1952,7 +1952,7 @@ struct Modulo
     }
 };
 
-template <typename LowLengths>
+template <typename LowLengths, bool ApplyModulo>
 struct Xor
 {
     using LowerIndex = MultiIndex<2>;
@@ -1981,8 +1981,15 @@ struct Xor
         idx_low(Number<0>{}) = idx_up[Number<0>{}];
 
-        idx_low(Number<1>{}) =
-            idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]);
+        if constexpr(ApplyModulo)
+        {
+            idx_low(Number<1>{}) =
+                idx_up[Number<1>{}] ^ (idx_up[Number<0>{}] % up_lengths_[Number<1>{}]);
+        }
+        else
+        {
+            idx_low(Number<1>{}) = idx_up[Number<1>{}] ^ idx_up[Number<0>{}];
+        }
     }
 
     template <typename LowIdxDiff,
include/ck/tensor_description/multi_index_transform_helper.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 
 #pragma once
@@ -128,9 +128,15 @@ __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
     return Modulo<Modulus, UpLength>{modulus, up_length};
 }
 
+template <typename LowLengths>
+__host__ __device__ constexpr auto make_xor_with_modulo_transform(const LowLengths& low_lengths)
+{
+    return Xor<LowLengths, true /*ApplyModulo*/>{low_lengths};
+}
+
 template <typename LowLengths>
 __host__ __device__ constexpr auto make_xor_transform(const LowLengths& low_lengths)
 {
-    return Xor<LowLengths>{low_lengths};
+    return Xor<LowLengths, false /*ApplyModulo*/>{low_lengths};
 }
 
 } // namespace ck
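The Xor transform now carries an ApplyModulo flag: make_xor_with_modulo_transform keeps the previous lowering, while make_xor_transform XORs the two upper indices directly. A minimal plain-C++ sketch of the two lowerings with made-up index values follows; the function names are illustrative helpers, not the CK transform classes.

// low[0] is always up[0]; only the second lowered coordinate differs between the variants.
#include <cstdint>
#include <iostream>

std::uint32_t lower_with_modulo(std::uint32_t up0, std::uint32_t up1, std::uint32_t len1)
{
    return up1 ^ (up0 % len1); // behaviour of Xor<LowLengths, true /*ApplyModulo*/>
}

std::uint32_t lower_without_modulo(std::uint32_t up0, std::uint32_t up1)
{
    return up1 ^ up0;          // behaviour of Xor<LowLengths, false /*ApplyModulo*/>
}

int main()
{
    const std::uint32_t up0 = 37, up1 = 5, len1 = 8;
    std::cout << "with modulo:    " << lower_with_modulo(up0, up1, len1) << '\n'; // 5 ^ (37 % 8) = 0
    std::cout << "without modulo: " << lower_without_modulo(up0, up1) << '\n';    // 5 ^ 37 = 32
}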
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp

@@ -53,8 +53,7 @@ __global__ void
                e_grid_desc_mblock_mperblock_nblock_nperblock,
            const Block2ETileMap block_2_etile_map)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx94__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
View file @
e4112de7
...
@@ -14,95 +14,137 @@
...
@@ -14,95 +14,137 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp"
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatA
,
typename
AGridDesc_AK0_M_K1
,
typename
FloatB
,
typename
BGridDesc_BK0_N_K1
,
typename
FloatC
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
AGridDesc_B_K0_M_K1
,
typename
BGridDesc_B_K0_N_K1
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2CTileMap
,
typename
ComputePtrOffsetOfBatch
,
typename
ComputePtrOffsetOfBatch
,
bool
HasMainKBlockLoop
>
index_t
NumBatchToMerge
,
bool
HasMainKBlockLoop
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
MinimumOccupancy
=
1
,
TailNumber
TailNum
=
TailNumber
::
Full
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
#endif
kernel_batched_gemm_xdlops_bwd_weight
(
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3
(
const
FloatA
*
__restrict__
p_a_grid
,
typename
GridwiseGemm
::
Argument
karg
,
const
FloatB
*
__restrict__
p_b_grid
,
const
AGridDesc_AK0_M_K1
a_grid_desc_ak0_m_ak1
,
FloatC
*
__restrict__
p_c_grid
,
const
BGridDesc_BK0_N_K1
b_grid_desc_bk0_n_bk1
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
index_t
batch_count
,
const
AGridDesc_B_K0_M_K1
a_b_k0_m_k1_grid_desc
,
const
BGridDesc_B_K0_N_K1
b_b_k0_n_k1_grid_desc
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2CTileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
ComputePtrOffsetOfBatch
compute
_p
t
r_
offset_of_batch
)
const
index_t
num_k
_p
e
r_
block
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
defined(__gfx94__))
const
index_t
num_blocks_per_batch
=
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
*
NumBatchToMerge
);
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
k_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
*
num_k_per_block
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
const
long_index_t
e_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetEPtrOffset
(
g_idx
)));
__shared__
FloatA
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatA
)];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
+
a_batch_offset
,
GridwiseGemm
::
template
Run
<
AGridDesc_AK0_M_K1
,
p_b_grid
+
b_batch_offset
,
BGridDesc_BK0_N_K1
,
p_c_grid
+
c_batch_offset
,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
p_shared
,
HasMainKBlockLoop
,
a_b_k0_m_k1_grid_desc
,
CGlobalMemoryDataOperation
,
b_b_k0_n_k1_grid_desc
,
TailNum
>(
karg
.
p_a_grid
+
a_batch_offset
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
karg
.
p_b_grid
+
b_batch_offset
,
a_element_op
,
karg
.
p_c_grid
+
e_batch_offset
,
b_element_op
,
p_shared
,
c_element_op
,
karg
,
block_2_ctile_map
);
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
k_idx
);
#else
#else
ignore
=
p_a_grid
;
ignore
=
karg
;
ignore
=
p_b_grid
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
ignore
=
p_c_grid
;
}
ignore
=
a_b_k0_m_k1_grid_desc
;
ignore
=
b_b_k0_n_k1_grid_desc
;
template
<
typename
GridwiseGemm
,
ignore
=
c_grid_desc_mblock_mperblock_nblock_nperblock
;
typename
AGridDesc_AK0_M_K1
,
ignore
=
a_element_op
;
typename
BGridDesc_BK0_N_K1
,
ignore
=
b_element_op
;
typename
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
ignore
=
c_element_op
;
typename
ComputePtrOffsetOfBatch
,
ignore
=
batch_count
;
index_t
NumBatchToMerge
,
ignore
=
block_2_ctile_map
;
bool
HasMainKBlockLoop
,
ignore
=
compute_ptr_offset_of_batch
;
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
index_t
MinimumOccupancy
=
1
,
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
0
);
TailNumber
TailNum
=
TailNumber
::
Full
>
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
0
);
__global__
void
compute_ptr_offset_of_batch
.
GetCPtrOffset
(
0
);
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
MinimumOccupancy
)
#endif
kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds
(
typename
GridwiseGemm
::
Argument
karg
,
const
AGridDesc_AK0_M_K1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_K1
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
index_t
num_k_per_block
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
// offset base pointer for each work-group
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
*
NumBatchToMerge
);
const
index_t
k_idx
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
*
num_k_per_block
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
const
long_index_t
e_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetEPtrOffset
(
g_idx
)));
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared_1
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run_2Lds
<
AGridDesc_AK0_M_K1
,
BGridDesc_BK0_N_K1
,
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
HasMainKBlockLoop
,
CGlobalMemoryDataOperation
,
TailNum
>(
karg
.
p_a_grid
+
a_batch_offset
,
karg
.
p_b_grid
+
b_batch_offset
,
karg
.
p_c_grid
+
e_batch_offset
,
p_shared_0
,
p_shared_1
,
karg
,
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
k_idx
);
#else
ignore
=
karg
;
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
}
...
@@ -121,7 +163,7 @@ template <ck::index_t NDimSpatial,
           ck::index_t BlockSize,
           ck::index_t MPerBlock,
           ck::index_t NPerBlock,
-          ck::index_t K0PerBlock,
+          ck::index_t KPerBlock,
           ck::index_t K1,
           ck::index_t MPerXdl,
           ck::index_t NPerXdl,
@@ -145,8 +187,11 @@ template <ck::index_t NDimSpatial,
...
@@ -145,8 +187,11 @@ template <ck::index_t NDimSpatial,
index_t
CShuffleNXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
,
index_t
CBlockTransferScalarPerVector_NWaveNPerXdl
,
typename
ComputeTypeA
=
InDataType
,
BlockGemmPipelineScheduler
BlkGemmPipeSched
=
BlockGemmPipelineScheduler
::
Intrawave
,
typename
ComputeTypeB
=
ComputeTypeA
>
BlockGemmPipelineVersion
BlkGemmPipelineVer
=
BlockGemmPipelineVersion
::
v1
,
index_t
NumBatchToMerge
=
1
,
typename
ComputeTypeA
=
InDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
struct
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
:
public
DeviceGroupedConvBwdWeight
<
NDimSpatial
,
:
public
DeviceGroupedConvBwdWeight
<
NDimSpatial
,
InLayout
,
InLayout
,
...
@@ -161,6 +206,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                                         ComputeTypeA,
                                         ComputeTypeB>
 {
+    static_assert(is_same_v<InElementwiseOperation, element_wise::PassThrough>);
+    static_assert(is_same_v<WeiElementwiseOperation, element_wise::PassThrough>);
+    static_assert(is_same_v<OutElementwiseOperation, element_wise::PassThrough>);
+
     using DeviceOp = DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle;

     using ADataType = OutDataType;
...
@@ -183,101 +232,123 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
     static constexpr auto K1Number = Number<K1>{};

-    static constexpr auto conv_to_gemm_transformer =
-        TransformConvBwdWeightToGemm<NDimSpatial,
-                                     MPerBlock,
-                                     NPerBlock,
-                                     K1Number,
-                                     K0PerBlock,
-                                     ConvBackwardWeightSpecialization>{};
+    static constexpr auto conv_to_gemm_transformer_v2 =
+        TransformConvBwdWeightToGemmV2<NDimSpatial,
+                                       MPerBlock,
+                                       NPerBlock,
+                                       K1Number,
+                                       KPerBlock / K1Number,
+                                       NumBatchToMerge,
+                                       ConvBackwardWeightSpecialization>{};
+
+    static constexpr auto conv_to_gemm_transformer_v1 =
+        TransformConvBwdWeightToGemm<NDimSpatial,
+                                     MPerBlock,
+                                     NPerBlock,
+                                     K1Number,
+                                     KPerBlock / K1Number,
+                                     ConvBackwardWeightSpecialization>{};

-    // Bytes per 32 lds bank: 32 * 4 bytes
-    static constexpr auto BankLength = 128;
-    static constexpr auto ElePerBank = BankLength / sizeof(ADataType);
-    // M1 & M0
-    static constexpr auto ABlockLdsM1PerBlock = ElePerBank / K1;
-    static constexpr auto ABlockLdsM0PerBlock = MPerBlock / ABlockLdsM1PerBlock;
-    static constexpr auto ABlockLdsM1Padding  = 4;
-    // N1 & N0
-    static constexpr auto BBlockLdsN1PerBlock = ElePerBank / K1;
-    static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
-    static constexpr auto BBlockLdsN1Padding  = 4;
+    static constexpr GemmSpecialization GemmSpec = GemmSpecialization::Default;
+
+    template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
+    static auto GetABCGridDesc()
+    {
+        const ck::index_t dim   = 1;
+        const ck::index_t batch = 1;
+        const std::array<ck::index_t, NDimSpatial> lengths{1, 1};
+        const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1};
+        const std::array<ck::index_t, NDimSpatial> params{1, 1};
+        return conv_to_gemm_transformer_v2
+            .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
+                dim, dim, dim, lengths, lengths, lengths, strides, strides, strides,
+                params, params, params, params, batch);
+    }

-    template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
+    template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
     static auto GetABCGridDesc()
     {
         const ck::index_t dim   = 1;
         const ck::index_t batch = 1;
-        const std::array<ck::index_t, NDimSpatial> lengths{1};
-        const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1};
-        const std::array<ck::index_t, NDimSpatial> params{1};
-        return conv_to_gemm_transformer
-            .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
-                dim, dim, dim, lengths, lengths, lengths, strides, strides, strides,
-                params, params, params, params, batch);
+        const std::array<ck::index_t, NDimSpatial> lengths{1, 1, 1};
+        const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
+        const std::array<ck::index_t, NDimSpatial> params{1, 1, 1};
+        return conv_to_gemm_transformer_v2
+            .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(
+                dim, dim, dim, lengths, lengths, lengths, strides, strides, strides,
+                params, params, params, params, batch);
     }

     template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
-    static auto GetABCGridDesc()
+    static auto GetElementwiseCGridDesc()
     {
         const ck::index_t dim   = 1;
         const ck::index_t batch = 1;
         const std::array<ck::index_t, NDimSpatial> lengths{1, 1};
         const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1};
         const std::array<ck::index_t, NDimSpatial> params{1, 1};
-        return conv_to_gemm_transformer
-            .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
-                dim, dim, dim, lengths, lengths, lengths, strides, strides, strides,
-                params, params, params, params, batch);
+        return conv_to_gemm_transformer_v1
+            .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
+                dim, dim, dim, lengths, lengths, lengths, strides, strides, strides,
+                params, params, params, params, batch)[I2];
     }

     template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
-    static auto GetABCGridDesc()
+    static auto GetElementwiseCGridDesc()
     {
         const ck::index_t dim   = 1;
         const ck::index_t batch = 1;
         const std::array<ck::index_t, NDimSpatial> lengths{1, 1, 1};
         const std::array<ck::index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
         const std::array<ck::index_t, NDimSpatial> params{1, 1, 1};
-        return conv_to_gemm_transformer
-            .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(
-                dim, dim, dim, lengths, lengths, lengths, strides, strides, strides,
-                params, params, params, params, batch);
+        return conv_to_gemm_transformer_v1
+            .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(
+                dim, dim, dim, lengths, lengths, lengths, strides, strides, strides,
+                params, params, params, params, batch)[I2];
     }

     using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
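The all-ones lengths, strides and params in these helpers are never used as real problem sizes; the calls exist only so `decltype` can name the descriptor types (which is exactly what `ABCGridDescs` does on the line above). A minimal stand-alone illustration of that pattern, with a hypothetical `MakeDesc` factory standing in for the transformer:

#include <tuple>
#include <type_traits>

// Hypothetical descriptor factory: the returned type depends on a template parameter,
// so the easiest way to name it is to call the factory with dummy arguments inside decltype.
template <int NDim>
static auto MakeDesc(int len)
{
    if constexpr(NDim == 2) { return std::tuple<int, int>{len, len}; }
    else                    { return std::tuple<int, int, int>{len, len, len}; }
}

template <int NDim>
using DescType = decltype(MakeDesc<NDim>(1)); // dummy argument, type only

static_assert(std::is_same_v<DescType<2>, std::tuple<int, int>>);
static_assert(std::is_same_v<DescType<3>, std::tuple<int, int, int>>);

int main() { return 0; }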
...
@@ -285,60 +356,56 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
     using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
     using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
     using CGridDesc_M_N     = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
+    using CElementwiseGridDesc_M_N =
+        remove_cvref_t<decltype(GetElementwiseCGridDesc<NDimSpatial>())>;

-    using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_bwd_weight<
-        BlockSize, ADataType, BDataType, AccDataType, AccDataType,
-        InMemoryDataOperationEnum::AtomicAdd,
-        AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N,
-        AElementwiseOperation, BElementwiseOperation, element_wise::PassThrough,
-        MPerBlock, NPerBlock, K0PerBlock, MPerXdl, NPerXdl, K1, MXdlPerWave, NXdlPerWave,
-        ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder,
-        ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector,
-        ABlockTransferDstScalarPerVector_K1,
-        false, // AThreadTransferSrcResetCoordinateAfterRun,
-        ABlockLdsAddExtraM, ABlockLdsM1PerBlock, ABlockLdsM0PerBlock, ABlockLdsM1Padding,
-        BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder,
-        BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector,
-        BBlockTransferDstScalarPerVector_K1,
-        false, // BThreadTransferSrcResetCoordinateAfterRun,
-        BBlockLdsAddExtraN, BBlockLdsN1PerBlock, BBlockLdsN0PerBlock, BBlockLdsN1Padding,
-        CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
-        CBlockTransferScalarPerVector_NWaveNPerXdl,
-        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
-        true, true, 1, PipelineVersion::v1, ComputeTypeA, ComputeTypeB>;
+    using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3<
+        tensor_layout::gemm::RowMajor,
+        tensor_layout::gemm::ColumnMajor,
+        tensor_layout::gemm::RowMajor,
+        ADataType, BDataType, AccDataType, AccDataType, AccDataType,
+        AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation,
+        GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, K1, K1,
+        MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave,
+        ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder,
+        ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector,
+        ABlockTransferDstScalarPerVector_K1, false, ABlockLdsAddExtraM,
+        BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder,
+        BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector,
+        BBlockTransferDstScalarPerVector_K1, false, BBlockLdsAddExtraN,
+        CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle,
+        CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
+        CBlockTransferScalarPerVector_NWaveNPerXdl,
+        BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB>;

     static constexpr index_t ClusterLengthMPerBlock =
         CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
...
@@ -347,8 +414,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
     using Block2TileMapElementwise = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock>;

     using GridwiseElementwise =
-        GridwiseElementwise<Tuple<CGridDesc_M_N>,
-                            Tuple<CGridDesc_M_N>,
+        GridwiseElementwise<Tuple<CElementwiseGridDesc_M_N>,
+                            Tuple<CElementwiseGridDesc_M_N>,
                             Tuple<const AccDataType*>,
                             Tuple<EDataType*>,
                             Block2TileMapElementwise,
...
@@ -366,10 +433,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
     // Argument
     using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
-        decltype(GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
-
-    using Block2CTileMap =
-        decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));
+        decltype(GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+            CGridDesc_M_N{}, 1, 1));

     struct Argument : public BaseArgument
     {
...
@@ -395,11 +460,10 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
             : p_a_grid_{p_out_grid},
               p_b_grid_{p_in_grid},
               p_e_grid_{p_wei_grid},
-              a_grid_desc_kbatch_k0_m_k1_{},
-              b_grid_desc_kbatch_k0_n_k1_{},
+              a_grid_desc_k0_m_k1_{},
+              b_grid_desc_k0_n_k1_{},
               ce_grid_desc_m_n_{},
               c_grid_desc_mblock_mperblock_nblock_nperblock_{},
-              block_2_ctile_map_{},
               compute_ptr_offset_of_batch_{},
               M01_{M01},
               N01_{N01},
...
@@ -430,7 +494,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                       begin(output_spatial_lengths_));

             const auto descs =
-                conv_to_gemm_transformer
+                conv_to_gemm_transformer_v2
                     .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
                         Conv_N_,
                         Conv_K_,
...
@@ -447,15 +511,34 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                         input_right_pads,
                         k_batch_);

-            a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
-            b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
+            a_grid_desc_k0_m_k1_ = descs[I0];
+            b_grid_desc_k0_n_k1_ = descs[I1];
             ce_grid_desc_m_n_    = descs[I2];

+            ce_elementwise_grid_desc_m_n_ =
+                conv_to_gemm_transformer_v1
+                    .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
+                        Conv_N_, Conv_K_, Conv_C_,
+                        input_spatial_lengths_, filter_spatial_lengths_, output_spatial_lengths_,
+                        b_g_n_c_wis_strides, e_g_k_c_xs_strides, a_g_n_k_wos_strides,
+                        conv_filter_strides, conv_filter_dilations,
+                        input_left_pads, input_right_pads, k_batch_)[I2];
-            block_2_ctile_map_ =
-                GridwiseGemm::MakeCBlockClusterAdaptor(ce_grid_desc_m_n_, M01, N01, k_batch_);
             elementwise_block_2_ctile_map_ = Block2TileMapElementwise{
                 ce_grid_desc_m_n_.GetLength(I0), ce_grid_desc_m_n_.GetLength(I1)};

+            const index_t GemmM = a_grid_desc_k0_m_k1_.GetLength(I1);
+            const index_t GemmN = b_grid_desc_k0_n_k1_.GetLength(I1);
+
             // A/B/C Batch Stride
             compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
             compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_n_c_wis_strides[0];
...
@@ -465,16 +548,11 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                                 end(filter_spatial_lengths_),
                                 index_t{1},
                                 std::multiplies<>{});

-            if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
-                                           b_grid_desc_kbatch_k0_n_k1_,
-                                           ce_grid_desc_m_n_,
-                                           block_2_ctile_map_))
-            {
-                c_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(
-                        ce_grid_desc_m_n_);
-            }
+            c_grid_desc_mblock_mperblock_nblock_nperblock_ =
+                GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    ce_grid_desc_m_n_,
+                    GridwiseGemm::CalculateMBlock(GemmM),
+                    GridwiseGemm::CalculateNBlock(GemmN));
         }

         std::size_t GetWorkspaceSizeBytes() const
...
@@ -486,12 +564,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
         const BDataType* p_b_grid_;
         EDataType* p_e_grid_;

-        AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
-        BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
+        AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
+        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
         CGridDesc_M_N ce_grid_desc_m_n_;
+        CElementwiseGridDesc_M_N ce_elementwise_grid_desc_m_n_;
         CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;

-        Block2CTileMap block_2_ctile_map_;
         Block2TileMapElementwise elementwise_block_2_ctile_map_;

         // for computing batch offset
...
@@ -525,96 +603,676 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
     void ShowInfo(const Argument& arg)
     {
-        std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
-                  << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
-                  << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
-                  << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", "
-                  << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl;
-
-        std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
-                  << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
-                  << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
-                  << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
-                  << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
+        std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
+                  << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
+                  << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
+
+        std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
+                  << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
+                  << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;

         std::cout << "arg.ce_grid_desc_m_n_{" << arg.ce_grid_desc_m_n_.GetLength(I0) << ", "
                   << arg.ce_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
     }
-    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+    float RunGemmV3(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
     {
-        if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
-                                        arg.b_grid_desc_kbatch_k0_n_k1_,
-                                        arg.ce_grid_desc_m_n_,
-                                        arg.block_2_ctile_map_))
+        const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
+        const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
+        const index_t GemmK =
+            arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
+
+        AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
+        // nullptr for output, will be set after workspace set
+        typename GridwiseGemm::Argument gemm_arg{
+            arg.p_a_grid_, arg.p_b_grid_, p_c_grid, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
+
+        index_t gdx, gdy, gdz;
+        std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
+            gemm_arg.M, gemm_arg.N, gemm_arg.KBatch, arg.Conv_G_ / NumBatchToMerge);
+
+        float ave_time = 0;
+
+        index_t k_grain = gemm_arg.KBatch * KPerBlock;
+        index_t K_split = (gemm_arg.K + k_grain - 1) / k_grain * (KPerBlock);
+
+        const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
+        const auto num_k_per_block =
+            arg.a_grid_desc_k0_m_k1_.GetLength(Number<0>{}) / gemm_arg.KBatch;
+
+        const auto clear_workspace = [&]() {
+            hip_check_error(hipMemsetAsync(
+                gemm_arg.p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
+        };
+
+        const auto Run = [&](const auto& kernel) {
+            if(stream_config.flush_cache)
+            {
+                typename GridwiseGemm::Argument gemm_arg_ = gemm_arg;
+                ck::utility::RotatingMemWrapper<typename GridwiseGemm::Argument> rotating_mem(
+                    gemm_arg_,
+                    stream_config.rotating_count,
+                    gemm_arg_.M * gemm_arg_.K * sizeof(ADataType),
+                    gemm_arg_.K * gemm_arg_.N * sizeof(BDataType));
+                rotating_mem.Print();
+
+                auto run_flush_cache = [&]() {
+                    // flush icache
+                    ck::utility::flush_icache();
+                    // rotating mem
+                    rotating_mem.Next();
+                    clear_workspace();
+                };
+
+                ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
+                    stream_config, run_flush_cache, kernel,
+                    dim3(gdx, gdy, gdz), dim3(BlockSize), 0,
+                    gemm_arg_,
+                    arg.a_grid_desc_k0_m_k1_,
+                    arg.b_grid_desc_k0_n_k1_,
+                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                    arg.compute_ptr_offset_of_batch_,
+                    num_k_per_block);
+            }
+            else
+            {
+                ave_time = launch_and_time_kernel_with_preprocess(
+                    stream_config, clear_workspace, kernel,
+                    dim3(gdx, gdy, gdz), dim3(BlockSize), 0,
+                    gemm_arg,
+                    arg.a_grid_desc_k0_m_k1_,
+                    arg.b_grid_desc_k0_n_k1_,
+                    arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                    arg.compute_ptr_offset_of_batch_,
+                    num_k_per_block);
+            }
+        };
+
+        constexpr index_t minimum_occupancy =
+            BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave ? 1 : 2;
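When `stream_config.flush_cache` is set, the timing path above rotates the A/B pointers through a larger pool and flushes the instruction cache before every launch, so repeated timed runs do not benefit from warm caches. The sketch below mimics only the rotation bookkeeping on the host side; the real `RotatingMemWrapper` and `flush_icache` utilities live in `ck::utility` and operate on device memory:

#include <cstdio>
#include <vector>

// Host-side sketch of "rotating memory": keep N copies of a buffer and hand out a
// different copy on every iteration so each timed run reads cold data.
struct RotatingBuffers
{
    RotatingBuffers(int copies, std::size_t elems)
        : pool(copies, std::vector<float>(elems, 1.0f)) {}

    const float* Next()
    {
        const float* p = pool[idx].data();
        idx            = (idx + 1) % pool.size();
        return p;
    }

    std::vector<std::vector<float>> pool;
    std::size_t idx = 0;
};

int main()
{
    RotatingBuffers a(/*copies=*/3, /*elems=*/1024);
    for(int iter = 0; iter < 5; ++iter)
    {
        const float* p = a.Next(); // a different copy each timed iteration
        std::printf("iter %d uses copy at %p\n", iter, static_cast<const void*>(p));
        // the timed kernel launch would go here in the real flow
    }
    return 0;
}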
+        if(has_main_k_block_loop)
+        {
+            // Tail number always full
+            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
+                         BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
+            {
+                if(gemm_arg.KBatch > 1)
+                {
+                    const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                        GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                        remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                        ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                        InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>;
+                    Run(kernel);
+                }
+                else
+                {
+                    const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                        GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                        remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                        ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                        InMemoryDataOperationEnum::Set, minimum_occupancy>;
+                    Run(kernel);
+                }
+            }
+            // Tail number could be One to Seven
+            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v2)
+            {
+                if(gemm_arg.KBatch > 1)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
+                    {
+                        const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                            GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                            InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::One>;
+                        Run(kernel);
+                    }
+                    else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
+                    {
+                        const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                            GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                            InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Full>;
+                        Run(kernel);
+                    }
+                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
+                        {
+                            const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                                GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                                InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Two>;
+                            Run(kernel);
+                        }
+                    }
+                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
+                        {
+                            const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                                GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                                InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Three>;
+                            Run(kernel);
+                        }
+                    }
+                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
+                        {
+                            const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                                GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                                InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Four>;
+                            Run(kernel);
+                        }
+                    }
+                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
+                        {
+                            const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                                GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                                InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Five>;
+                            Run(kernel);
+                        }
+                    }
+                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
+                        {
+                            const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                                GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                                InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Six>;
+                            Run(kernel);
+                        }
+                    }
+                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
+                        {
+                            const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                                GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                                InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Seven>;
+                            Run(kernel);
+                        }
+                    }
+                }
+                else
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::One)
+                    {
+                        const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                            GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                            InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::One>;
+                        Run(kernel);
+                    }
+                    else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
+                    {
+                        const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                            GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                            InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Full>;
+                        Run(kernel);
+                    }
+                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 2)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Two)
+                        {
+                            const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                                GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                                InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Two>;
+                            Run(kernel);
+                        }
+                    }
+                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 3)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Three)
+                        {
+                            const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                                GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                                InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Three>;
+                            Run(kernel);
+                        }
+                    }
+                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 4)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Four)
+                        {
+                            const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                                GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                                InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Four>;
+                            Run(kernel);
+                        }
+                    }
+                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 5)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Five)
+                        {
+                            const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                                GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                                InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Five>;
+                            Run(kernel);
+                        }
+                    }
+                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 6)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Six)
+                        {
+                            const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                                GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                                InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Six>;
+                            Run(kernel);
+                        }
+                    }
+                    if constexpr(GridwiseGemm::BlockwiseGemmPipe::PrefetchStages > 7)
+                    {
+                        if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Seven)
+                        {
+                            const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                                GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                                InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Seven>;
+                            Run(kernel);
+                        }
+                    }
+                }
+            }
+            // Tail number could be Odd or Even
+            else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
+            {
+                if(gemm_arg.KBatch > 1)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                    {
+                        const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
+                            GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                            InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Odd>;
+                        Run(kernel);
+                    }
+                    else
+                    {
+                        const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
+                            GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                            InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Even>;
+                        Run(kernel);
+                    }
+                }
+                else
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                    {
+                        const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
+                            GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                            InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Odd>;
+                        Run(kernel);
+                    }
+                    else
+                    {
+                        const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds<
+                            GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                            InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Even>;
+                        Run(kernel);
+                    }
+                }
+            }
+            else
+            {
+                if(gemm_arg.KBatch > 1)
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                    {
+                        const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                            GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                            InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Odd>;
+                        Run(kernel);
+                    }
+                    else
+                    {
+                        const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                            GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                            InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy, TailNumber::Even>;
+                        Run(kernel);
+                    }
+                }
+                else
+                {
+                    if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
+                    {
+                        const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                            GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                            InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Odd>;
+                        Run(kernel);
+                    }
+                    else
+                    {
+                        const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                            GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                            remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                            remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                            ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, true,
+                            InMemoryDataOperationEnum::Set, minimum_occupancy, TailNumber::Even>;
+                        Run(kernel);
+                    }
+                }
+            }
+        }
+        else
+        {
-        {
-            throw std::runtime_error(
-                "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
-        }
+            // Tail number always 1
+            if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
+            {
+                if(gemm_arg.KBatch > 1)
+                {
+                    const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                        GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                        remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                        ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, false,
+                        InMemoryDataOperationEnum::AtomicAdd, minimum_occupancy>;
+                    Run(kernel);
+                }
+                else
+                {
+                    const auto kernel = kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3<
+                        GridwiseGemm, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
+                        remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
+                        remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
+                        ComputePtrOffsetOfStridedBatch<I1, I1, I0>, NumBatchToMerge, false,
+                        InMemoryDataOperationEnum::Set, minimum_occupancy>;
+                    Run(kernel);
+                }
+            }
+        }
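The branch ladder above is mechanical: each combination of pipeline version, KBatch (atomic-add versus plain store) and tail-iteration count selects one fully specialized kernel instantiation. A compressed host-side sketch of the same selection logic, with illustrative enum and kernel names only (not the library's API):

#include <cstdio>
#include <string>

enum class PipelineVersion { v1, v2, v3, v4 };
enum class Tail { Odd, Even, One, Two, Three, Four, Five, Six, Seven, Full };

// Illustrative only: returns a label describing which specialization would be launched.
std::string SelectKernel(PipelineVersion ver, int k_batch, Tail tail)
{
    const std::string op = (k_batch > 1) ? "AtomicAdd" : "Set";
    switch(ver)
    {
    case PipelineVersion::v1:
    case PipelineVersion::v3: return "v3_kernel<" + op + ", TailNumber::Full>";
    case PipelineVersion::v2: return "v3_kernel<" + op + ", TailNumber::#" +
                                     std::to_string(static_cast<int>(tail)) + ">";
    case PipelineVersion::v4: return "v3_2lds_kernel<" + op + ", " +
                                     (tail == Tail::Odd ? "TailNumber::Odd>" : "TailNumber::Even>");
    }
    return "unsupported";
}

int main()
{
    std::printf("%s\n", SelectKernel(PipelineVersion::v4, 2, Tail::Odd).c_str());
    std::printf("%s\n", SelectKernel(PipelineVersion::v1, 1, Tail::Full).c_str());
    return 0;
}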
+        return ave_time;
+    }
-        const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
-        const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
-
-        auto launch_gemm_kernel = [&](auto has_main_k_block_loop) {
-            AccDataType* p_c_grid = type_convert<AccDataType*>(arg.p_workspace_);
-            const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) * arg.Conv_G_;
-
-            constexpr bool has_main_loop = has_main_k_block_loop.value;
-
-            auto preprocess = [&]() {
-                hip_check_error(hipMemsetAsync(
-                    p_c_grid, 0, arg.GetWorkspaceSizeBytes(), stream_config.stream_id_));
-            };
-
-            const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
-                GridwiseGemm, ADataType, BDataType, AccDataType,
-                OutElementwiseOperation, InElementwiseOperation, element_wise::PassThrough,
-                remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
-                remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
-                remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
-                remove_reference_t<DeviceOp::Block2CTileMap>,
-                ComputePtrOffsetOfStridedBatch<I1, I1, I0>, has_main_loop>;
-
-            return launch_and_time_kernel_with_preprocess(
-                stream_config, preprocess, kernel, dim3(grid_size), dim3(BlockSize), 0,
-                arg.p_a_grid_, arg.p_b_grid_, p_c_grid,
-                arg.a_element_op_, arg.b_element_op_, element_wise::PassThrough{},
-                arg.Conv_G_,
-                arg.a_grid_desc_kbatch_k0_m_k1_,
-                arg.b_grid_desc_kbatch_k0_n_k1_,
-                arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                arg.block_2_ctile_map_,
-                arg.compute_ptr_offset_of_batch_);
-        };
+
+    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
+    {
         auto launch_elementwise_kernel = [&]() {
             const AccDataType* p_c_grid = type_convert<const AccDataType*>(arg.p_workspace_);
             const index_t grid_size =
-                arg.elementwise_block_2_ctile_map_.CalculateGridSize(arg.ce_grid_desc_m_n_) *
+                arg.elementwise_block_2_ctile_map_.CalculateGridSize(
+                    arg.ce_elementwise_grid_desc_m_n_) *
                 arg.Conv_G_;

             std::array<index_t, I1> in_out_batch_strides = {
                 arg.compute_ptr_offset_of_batch_.BatchStrideC_};

             const auto kernel = kernel_batched_elementwise<GridwiseElementwise,
-                                                           ck::Tuple<CGridDesc_M_N>,
-                                                           ck::Tuple<CGridDesc_M_N>,
+                                                           ck::Tuple<CElementwiseGridDesc_M_N>,
+                                                           ck::Tuple<CElementwiseGridDesc_M_N>,
                                                            ck::Tuple<const AccDataType*>,
                                                            ck::Tuple<EDataType*>,
                                                            Block2TileMapElementwise,
...
@@ -627,8 +1285,8 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                                                dim3(grid_size),
                                                dim3(BlockSize),
                                                0,
-                                               make_tuple(arg.ce_grid_desc_m_n_),
-                                               make_tuple(arg.ce_grid_desc_m_n_),
+                                               make_tuple(arg.ce_elementwise_grid_desc_m_n_),
+                                               make_tuple(arg.ce_elementwise_grid_desc_m_n_),
                                                make_tuple(p_c_grid),
                                                make_tuple(arg.p_e_grid_),
                                                arg.elementwise_block_2_ctile_map_,
...
@@ -638,16 +1296,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
                 in_out_batch_strides);
         };

-        float avg_time = 0;
-
-        if(has_main_k0_block_loop)
-        {
-            avg_time = launch_gemm_kernel(integral_constant<bool, true>{});
-        }
-        else
-        {
-            avg_time = launch_gemm_kernel(integral_constant<bool, false>{});
-        }
+        float avg_time = RunGemmV3(arg, stream_config);

         avg_time += launch_elementwise_kernel();
         return avg_time;
     }
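The second stage remains a plain elementwise pass: `RunGemmV3` accumulates the weight gradient into an `AccDataType` (fp32) workspace, and the elementwise kernel then reads that workspace and writes the final `EDataType` tensor. A stand-alone sketch of that cast-and-store step, with placeholder types and sizes:

#include <cstdint>
#include <cstdio>
#include <vector>

using AccType = float;        // accumulation/workspace type in the sketch
using EType   = std::int16_t; // stand-in for a 16-bit output type

int main()
{
    const std::size_t n = 8;
    std::vector<AccType> workspace(n); // filled by the GEMM stage (atomic adds)
    for(std::size_t i = 0; i < n; ++i)
        workspace[i] = 0.5f * static_cast<float>(i);

    std::vector<EType> wei_grad(n);    // final weight-gradient tensor
    for(std::size_t i = 0; i < n; ++i) // the elementwise pass-through/cast stage
        wei_grad[i] = static_cast<EType>(workspace[i]);

    std::printf("wei_grad[7] = %d\n", static_cast<int>(wei_grad[7]));
    return 0;
}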
...
@@ -667,6 +1316,23 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
+        const index_t GemmM = arg.a_grid_desc_k0_m_k1_.GetLength(I1);
+        const index_t GemmN = arg.b_grid_desc_k0_n_k1_.GetLength(I1);
+        const index_t GemmK =
+            arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
+
+        typename GridwiseGemm::Argument gemm_arg{
+            nullptr, nullptr, nullptr, GemmM, GemmN, GemmK, I0, I0, I0, arg.k_batch_};
+
+        const auto num_k_loop = gemm_arg.AK0 / (KPerBlock / K1);
+        if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
+        {
+            if(num_k_loop <= GridwiseGemm::BlockwiseGemmPipe::PrefetchStages)
+            {
+                return false;
+            }
+        }
+
         // Check this here, it allows to use other instances from factory even
         // if workspace is not allocated
         if(!arg.p_workspace_)
...
@@ -723,10 +1389,38 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
             }
         }

+        if constexpr(NumBatchToMerge > 1)
+        {
+            // support only if whole M and N can be proccessed on one block
+            if(!(GemmM <= MPerBlock && GemmN <= NPerBlock))
+            {
+                return false;
+            }
+            if(!(arg.Conv_C_ == 1 && arg.Conv_K_ == 1))
+            {
+                return false;
+            }
+            if(arg.Conv_G_ % NumBatchToMerge != 0)
+            {
+                return false;
+            }
+        }
+
+        if(!(arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0 &&
+             arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0))
+        {
+            if(!(arg.Conv_K_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideA_ == 1))
+            {
+                return false;
+            }
+            if(!(arg.Conv_C_ == 1 && arg.compute_ptr_offset_of_batch_.BatchStrideB_ == 1))
+            {
+                return false;
+            }
+        }
+
         // vector load A/B matrix from global memory
-        if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
-             arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
-             arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
+        if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1))
         {
             return false;
         }
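Together with the earlier prefetch check, these conditions reject configurations that the v3 pipeline cannot run: the number of K loop iterations has to exceed the pipeline's prefetch depth, otherwise there is no main loop left to hide the prefetch in. The same arithmetic in a few lines of plain C++, using placeholder numbers and a simplified form of the K-split computation:

#include <cstdio>

int main()
{
    // Placeholder values: a K extent split into KBatch pieces, tiled by KPerBlock.
    const int K = 256, KBatch = 2, KPerBlock = 32, PrefetchStages = 3;
    const int k_grain    = KBatch * KPerBlock;
    const int K_split    = (K + k_grain - 1) / k_grain * KPerBlock; // K handled per batch
    const int num_k_loop = K_split / KPerBlock;
    const bool supported = num_k_loop > PrefetchStages;
    std::printf("num_k_loop=%d, prefetch=%d, supported=%d\n",
                num_k_loop, PrefetchStages, supported);
    return 0;
}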
...
@@ -737,11 +1431,7 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
             return false;
         }

-        // Gridwise GEMM size
-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
-                                           arg.b_grid_desc_kbatch_k0_n_k1_,
-                                           arg.ce_grid_desc_m_n_,
-                                           arg.block_2_ctile_map_);
+        return true;
     }

     bool IsSupportedArgument(const BaseArgument* p_arg) override
...
@@ -840,13 +1530,24 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
     {
         auto str = std::stringstream();

+        std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
+            {BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
+            {BlockGemmPipelineScheduler::Interwave, "Interwave"}};
+
+        std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
+            {BlockGemmPipelineVersion::v1, "v1"},
+            {BlockGemmPipelineVersion::v2, "v2"},
+            {BlockGemmPipelineVersion::v3, "v3"},
+            {BlockGemmPipelineVersion::v4, "v4"},
+            {BlockGemmPipelineVersion::v5, "v5"}};
+
         // clang-format off
         str << "DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle"
             << "<"
             << BlockSize << ", "
             << MPerBlock << ", "
             << NPerBlock << ", "
-            << K0PerBlock << ", "
+            << KPerBlock << ", "
             << getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
             << K1 << ", "
             << MXdlPerWave << ", "
...
@@ -857,7 +1558,12 @@ struct DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
             << BBlockTransferDstScalarPerVector_K1 << ", "
             << CShuffleMXdlPerWavePerShuffle << ", "
             << CShuffleNXdlPerWavePerShuffle << ", "
             << CBlockTransferScalarPerVector_NWaveNPerXdl
+            << ", "
+            << "BlkGemmPipelineScheduler: "
+            << BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
+            << "BlkGemmPipelineVersion: "
+            << BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
+            << NumBatchToMerge
             << ">";
         // clang-format on
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
View file @ e4112de7
...
@@ -45,8 +45,7 @@ __global__ void
        const BElementwiseOperation b_element_op,
        const CDEElementwiseOperation cde_element_op)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const index_t KBatch = 1;
...
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
View file @ e4112de7
...
@@ -50,8 +50,7 @@ __global__ void
        const CElementwiseOperation c_element_op,
        const Block2CTileMap block_2_ctile_map)
{
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
-    defined(__gfx1102__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
    __shared__ char p_shared[GridwiseGemm::SharedMemTrait::lds_size];

    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
...
@@ -80,7 +79,7 @@ __global__ void
    ignore = b_element_op;
    ignore = c_element_op;
    ignore = block_2_ctile_map;
-#endif // end of if (defined(__gfx1100__))
+#endif // end of if (defined(__gfx11__))
}

// Assume B is Col-Major
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp 0 → 100644
View file @ e4112de7

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          tensor_operation::device::GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1Value,
          index_t BK1Value,
          index_t MPerXdl,
          index_t NPerXdl,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
          typename ComputeTypeA                       = CDataType,
          typename ComputeTypeB                       = ComputeTypeA>
struct GridwiseGemm_xdl_cshuffle_v3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    // K1 should be Number<...>
    static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
    static constexpr auto AK1Number = Number<AK1Value>{};
    static constexpr auto BK1Number = Number<BK1Value>{};

    static constexpr index_t KPack =
        math::max(math::lcm(AK1Number, BK1Number),
                  MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
    __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
    {
        return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
    }

    __host__ static auto CalculateMPadded(index_t M)
    {
        return math::integer_least_multiple(M, MPerBlock);
    }

    __host__ static auto CalculateNPadded(index_t N)
    {
        return math::integer_least_multiple(N, NPerBlock);
    }

    __host__ static auto CalculateKPadded(index_t K)
    {
        return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
    }

    __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    }

    __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    }

    __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    }

    __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
    {
        constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
        auto K_t                = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
    }

    __host__ static auto CalculateMBlock(index_t M)
    {
        return math::integer_divide_ceil(M, MPerBlock);
    }

    __host__ static auto CalculateNBlock(index_t N)
    {
        return math::integer_divide_ceil(N, NPerBlock);
    }
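These helpers round the problem sizes up to tile multiples; for example, the two-argument `CalculateKPadded(K, K_Batch)` rounds K up so it splits evenly into K_Batch chunks of whole KPerBlock tiles. The same arithmetic in a few lines of plain C++ (sample numbers only):

#include <cstdio>

int integer_least_multiple(int x, int m) { return (x + m - 1) / m * m; }

int main()
{
    const int M = 1000, MPerBlock = 128;
    const int K = 500, KPerBlock = 32, KBatch = 4;

    const int MPadded = integer_least_multiple(M, MPerBlock);  // 1024
    const int K_t     = KBatch * KPerBlock;                    // 128
    const int KPadded = (K + K_t - 1) / K_t * KPerBlock;       // per-batch K: 128
    const int MBlock  = (MPadded + MPerBlock - 1) / MPerBlock; // 8

    std::printf("MPadded=%d KPadded=%d MBlock=%d\n", MPadded, KPadded, MBlock);
    return 0;
}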
    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
    {
        constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
        constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});

        return transform_tensor_descriptor(
            TileDesc_K0_MN_K1{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }
    struct Problem
    {
        __host__ Problem(index_t M_,
                         index_t N_,
                         index_t K_,
                         index_t StrideA_,
                         index_t StrideB_,
                         index_t StrideC_,
                         index_t KBatch_)
            : M{M_},
              N{N_},
              K{K_},
              StrideA{StrideA_},
              StrideB{StrideB_},
              StrideC{StrideC_},
              KBatch{KBatch_},
              MPadded{CalculateMPadded(M_)},
              NPadded{CalculateNPadded(N_)},
              KRead{CalculateKRead(K_, KBatch_)},
              KPadded{CalculateKPadded(K_, KBatch_)},
              AK0{CalculateAK0Padded(K_, KBatch_)},
              BK0{CalculateBK0Padded(K_, KBatch_)},
              MBlock{CalculateMBlock(M_)},
              NBlock{CalculateNBlock(N_)}
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {"
                      << "M:" << M << ", "
                      << "N:" << N << ", "
                      << "K:" << K << ", "
                      << "SA:" << StrideA << ", "
                      << "SB:" << StrideB << ", "
                      << "SC:" << StrideC << ", "
                      << "MP:" << MPadded << ", "
                      << "NP:" << NPadded << ", "
                      << "KRead:" << KRead << ", "
                      << "KP:" << KPadded << ", "
                      << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << ", "
                      << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
        }

        index_t M;
        index_t N;
        index_t K;
        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
        index_t KBatch;
        index_t MPadded;
        index_t NPadded;
        index_t KRead;
        index_t KPadded;
        index_t AK0;
        index_t BK0;
        index_t MBlock;
        index_t NBlock;
    };

    // Argument
    struct Argument : public tensor_operation::device::BaseArgument, public Problem
    {
        __host__ Argument(const ADataType* p_a_grid_,
                          const BDataType* p_b_grid_,
                          CDataType* p_c_grid_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
                          index_t StrideC_,
                          index_t k_batch_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_c_grid{p_c_grid_}
        {
        }

        const ADataType* p_a_grid;
        const BDataType* p_b_grid;
        CDataType* p_c_grid;
    };
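    // Illustrative sketch (not part of the original kernel): a minimal host-side way to fill
    // in an Argument. The sizes, strides and split-K factor below are placeholders; in
    // practice the enclosing device operation derives them from its own arguments and layout
    // tags rather than hard-coding row-major strides as done here.
    __host__ static auto
    MakeExampleArgument(const ADataType* p_a, const BDataType* p_b, CDataType* p_c)
    {
        const index_t M = 3840, N = 4096, K = 2048;
        const index_t StrideA = K, StrideB = N, StrideC = N; // row-major placeholder strides
        const index_t KBatch  = 4;                           // split-K factor (assumption)
        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
    }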
    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        if constexpr(ABlockLdsExtraM)
        {
            return make_naive_tensor_descriptor(
                make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
                make_tuple(AK1Number, Number<KPerBlock + ABlockLdsExtraM>{}, I1));
        }
        // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
        // in some cases.
        else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
                                           ? 1
                                           : 32 * 4 / KPerBlock / sizeof(ADataType);

            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
                make_tuple(
                    AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
                make_tuple(AK1Number, Number<KPerBlock * MLdsLayer>{}, I1));

            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                a_lds_block_desc,
                make_tuple(make_xor_with_modulo_transform(make_tuple(
                               Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}));

            constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
                           make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));

            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_ak0_mldslayer_m_ak1,
                make_tuple(make_pass_through_transform(AK0Number),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return a_lds_block_desc_ak0_m_ak1;
        }
        else // ColumnMajor A
        {
            // kfold and mpair dimension is not always required.
            // more dimension in merge_transform increase the difficulty of generating immarg offset
            // for compiler.
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(ADataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1<=mpair<=n0
            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));

            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<K0PerThreadWrite>{},
                           Number<KThreadReadPerm * M1>{},
                           Number<kfold * M0 / mpair>{},
                           Number<mpair>{},
                           AK1Number));

            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                a_lds_block_desc,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_xor_with_modulo_transform(
                        make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
                    make_pass_through_transform(Number<mpair>{}),
                    make_pass_through_transform(AK1Number)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));

            constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
                a_lds_block_desc_permuted,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<M1>{})),
                    make_unmerge_transform(make_tuple(Number<kfold>{}, Number<M0 / mpair>{})),
                    make_pass_through_transform(Number<mpair>{}),
                    make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<1>{},
                           Sequence<2>{},
                           Sequence<0, 3>{},
                           Sequence<4, 5>{},
                           Sequence<6>{},
                           Sequence<7>{}));

            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_unmerged,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(Number<KThreadReadPerm>{},
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
                                          Number<kfold>{},
                                          Number<K0PerThreadWrite>{})),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<M0 / mpair>{}, Number<mpair>{}, Number<M1>{})),
                           make_pass_through_transform(AK1Number)),
                make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return a_lds_block_desc_ak0_m_ak1;
        }
    }
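    // Illustrative sketch (not part of the original kernel): the RowMajor branch above folds
    // MLdsLayer rows of the block tile into one LDS row so that a row spans roughly a
    // 128-byte window before the XOR permutation is applied. With a hypothetical 2-byte
    // ADataType and KPerBlock = 32, one K-slice is 64 bytes and MLdsLayer evaluates to
    // (32 * 4) / 32 / 2 = 2.
    static constexpr int ExampleMLdsLayer()
    {
        constexpr int bytes_per_elem = 2;  // e.g. fp16/bf16 (assumption)
        constexpr int k_per_block    = 32; // assumption
        return (32 * 4 / k_per_block / bytes_per_elem) < 1
                   ? 1
                   : 32 * 4 / k_per_block / bytes_per_elem; // 2 for these numbers
    }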
    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        if constexpr(BBlockLdsExtraN)
        {
            return make_naive_tensor_descriptor(
                make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
                make_tuple(BK1Number, Number<KPerBlock + BBlockLdsExtraN>{}, I1));
        }
        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
            // NLdsLayer * K0 as logical Bank
            constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
                                           ? 1
                                           : 32 * 4 / KPerBlock / sizeof(BDataType);

            constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
                make_tuple(
                    BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
                make_tuple(BK1Number, Number<KPerBlock * NLdsLayer>{}, I1));

            constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
                b_lds_block_desc,
                make_tuple(make_xor_with_modulo_transform(make_tuple(
                               Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}));

            constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number<NLdsLayer>{})),
                           make_pass_through_transform(Number<NPerBlock / NLdsLayer>{}),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));

            constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_bk0_nldslayer_n_bk1,
                make_tuple(make_pass_through_transform(BK0Number),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<NPerBlock / NLdsLayer>{}, Number<NLdsLayer>{})),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return b_lds_block_desc_bk0_n_bk1;
        }
        else // RowMajor B
        {
            constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1 = NPerBlock / N0;

            constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
            constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / NPerXdl;
            constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

            constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
                                       ? 1
                                       : 128 / (BK1Number * N0 * sizeof(BDataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1<=npair<=n0
            constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
                                       ? 1
                                       : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
                                              ? N0
                                              : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));

            constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<K0PerThreadWrite>{},
                           Number<KThreadReadPerm * N1>{},
                           Number<kfold * N0 / npair>{},
                           Number<npair>{},
                           BK1Number));

            constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
                b_lds_block_desc,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_xor_with_modulo_transform(
                        make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
                    make_pass_through_transform(Number<npair>{}),
                    make_pass_through_transform(BK1Number)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));

            constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
                b_lds_block_desc_permuted,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<N1>{})),
                    make_unmerge_transform(make_tuple(Number<kfold>{}, Number<N0 / npair>{})),
                    make_pass_through_transform(Number<npair>{}),
                    make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<1>{},
                           Sequence<2>{},
                           Sequence<0, 3>{},
                           Sequence<4, 5>{},
                           Sequence<6>{},
                           Sequence<7>{}));

            constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_unmerged,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(Number<KThreadReadPerm>{},
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
                                          Number<kfold>{},
                                          Number<K0PerThreadWrite>{})),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<N0 / npair>{}, Number<npair>{}, Number<N1>{})),
                           make_pass_through_transform(BK1Number)),
                make_tuple(Sequence<0, 1, 4, 2>{}, Sequence<5, 6, 3>{}, Sequence<7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return b_lds_block_desc_bk0_n_bk1;
        }
    }
    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                           I1,
                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }

    using BlockwiseGemmPipe =
        remove_cvref_t<decltype(BlockGemmPipeline_Selector<
                                BlkGemmPipelineVer,
                                BlkGemmPipeSched,
                                BlockSize,
                                ADataType,
                                BDataType,
                                ComputeTypeA,
                                AccDataType,
                                decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
                                decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
                                decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
                                    GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
                                decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
                                    GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
                                ABlockTransferSrcScalarPerVector,
                                BBlockTransferSrcScalarPerVector,
                                MPerBlock,
                                NPerBlock,
                                KPerBlock,
                                MPerXdl,
                                NPerXdl,
                                MXdlPerWave,
                                NXdlPerWave,
                                KPack>())>;
    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for C shuffle in LDS
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned * sizeof(ADataType) +
                          b_block_space_size_aligned * sizeof(BDataType)),
                         c_block_size * sizeof(CShuffleDataType));
    }
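    // Illustrative sketch (not part of the original kernel): GetSharedMemoryNumberOfByte()
    // returns the larger of the A+B staging area and the C-shuffle area because the C shuffle
    // reuses the same LDS once the main loop has finished. A hedged back-of-the-envelope,
    // ignoring the alignment and extra-M/N padding the real descriptors add, for hypothetical
    // 256 x 128 x 32 tiles with 2-byte A/B data, a 32 x 128 C-shuffle slab and 4-byte
    // CShuffleDataType:
    static constexpr int ExampleLdsBytes()
    {
        constexpr int a_bytes = 256 * 32 * 2; // MPerBlock * KPerBlock * sizeof(ADataType) = 16384
        constexpr int b_bytes = 128 * 32 * 2; // NPerBlock * KPerBlock * sizeof(BDataType) = 8192
        constexpr int c_bytes = 32 * 128 * 4; // shuffle slab * sizeof(CShuffleDataType)   = 16384
        return (a_bytes + b_bytes) > c_bytes ? (a_bytes + b_bytes) : c_bytes; // 24576 bytes
    }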
    __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }

    __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }
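    // Illustrative sketch (not part of the original kernel): the two helpers above only derive
    // the per-block K iteration count; whether that count is long enough for the software
    // pipelined "hot loop", and which tail variant to run, is decided by the selected
    // BlockwiseGemmPipe. Assuming K = 4096 and KPerBlock = 32, the pipeline sees
    // num_loop = 128 iterations.
    static constexpr int ExampleNumKLoop()
    {
        constexpr int k = 4096, k_per_block = 32; // assumptions
        return k / k_per_block;                   // 128
    }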
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }

    // return block_id to C matrix tile idx (m0, n0) mapping
    // if arch = gfx942
    using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Odd>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               const Problem& problem,
                               const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock,
                               const index_t k_id = 0)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};

        // divide block work by [M, N]
        const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};

        const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
            make_multi_index(static_cast<index_t>(blockIdx.x)));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            AElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<AK0Number, MPerBlock, AK1Number>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            ADataType,
            ADataType,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            Sequence<0, 1, 2>,
            ABlockTransferSrcVectorDim,
            2,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            1,
            1,
            AThreadTransferSrcResetCoordinateAfterRun,
            true,
            BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                                make_multi_index(k_id, m_block_data_idx_on_grid, 0),
                                                a_element_op,
                                                a_block_desc_ak0_m_ak1,
                                                make_multi_index(0, 0, 0),
                                                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            BElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<BK0Number, NPerBlock, BK1Number>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            BDataType,
            BDataType,
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
            Sequence<0, 1, 2>,
            BBlockTransferSrcVectorDim,
            2,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_BK1,
            1,
            1,
            BThreadTransferSrcResetCoordinateAfterRun,
            true,
            BlockwiseGemmPipe::GlobalBufferNum>(b_grid_desc_bk0_n_bk1,
                                                make_multi_index(k_id, n_block_data_idx_on_grid, 0),
                                                b_element_op,
                                                b_block_desc_bk0_n_bk1,
                                                make_multi_index(0, 0, 0),
                                                ck::tensor_operation::element_wise::PassThrough{});

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        // Cast after lds
        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<BDataType*>(p_shared) +
                a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);

        // Blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            (KPerBlock * problem.KBatch));

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_buf,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_buf,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);

        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                        M1,                                      // M1 = MWave
                        M2,                                      // M2 * M3 * M4 = MPerXdl
                        M3,
                        M4)),
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                        N1,                                      // N1 = NWave
                        N2))),                                   // N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   CShuffleDataType,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            M4,
                                                            I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};

            // shuffle: blockwise copy C from LDS to global
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,            // ThreadGroup
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                CShuffleDataType,     // typename SrcData,
                CDataType,            // typename DstData,
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                3,                                              // index_t VectorDim,
                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_m_id, 0, block_n_id, 0),
                 c_element_op};

            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1,
                                           1,
                                           M2,
                                           1,
                                           M4,
                                           1>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_c_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread write its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // each block copy its data from LDS to global
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
    }
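    // Illustrative sketch (not part of the original kernel): when KBatch > 1 each k_id
    // instance of Run() produces a partial C tile, so callers of kernels like this commonly
    // instantiate CGlobalMemoryDataOperation as an accumulating store and otherwise as a
    // plain set. The helper below only illustrates that host-side choice; the real dispatch
    // logic lives in the device operation, not in this gridwise kernel.
    static constexpr InMemoryDataOperationEnum ExampleCGlobalMemoryDataOperation(index_t k_batch)
    {
        return k_batch > 1 ? InMemoryDataOperationEnum::AtomicAdd : InMemoryDataOperationEnum::Set;
    }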
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Odd>
    __device__ static void Run_2Lds(const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    CDataType* p_c_grid,
                                    void* p_shared_0,
                                    void* p_shared_1,
                                    const Problem& problem,
                                    const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                                    const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                                    const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                                    const index_t k_id = 0)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};

        // divide block work by [M, N]
        const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};

        const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
            make_multi_index(static_cast<index_t>(blockIdx.x)));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // HACK: this force m/n_block_data_idx_on_grid into SGPR
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        // lds max alignment
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();

        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            AElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<AK0Number, MPerBlock, AK1Number>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            ADataType,
            ADataType,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            Sequence<0, 1, 2>,
            ABlockTransferSrcVectorDim,
            2,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            1,
            1,
            AThreadTransferSrcResetCoordinateAfterRun,
            true,
            BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                                make_multi_index(k_id, m_block_data_idx_on_grid, 0),
                                                a_element_op,
                                                a_block_desc_ak0_m_ak1,
                                                make_multi_index(0, 0, 0),
                                                ck::tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
            ThisThreadBlock,
            BElementwiseOperation,
            ck::tensor_operation::element_wise::PassThrough,
            InMemoryDataOperationEnum::Set,
            Sequence<BK0Number, NPerBlock, BK1Number>,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            BDataType,
            BDataType,
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcAccessOrder,
            Sequence<0, 1, 2>,
            BBlockTransferSrcVectorDim,
            2,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_BK1,
            1,
            1,
            BThreadTransferSrcResetCoordinateAfterRun,
            true,
            BlockwiseGemmPipe::GlobalBufferNum>(b_grid_desc_bk0_n_bk1,
                                                make_multi_index(k_id, n_block_data_idx_on_grid, 0),
                                                b_element_op,
                                                b_block_desc_bk0_n_bk1,
                                                make_multi_index(0, 0, 0),
                                                ck::tensor_operation::element_wise::PassThrough{});

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<BDataType*>(p_shared_0) +
                a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<BDataType*>(p_shared_1) +
                a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);

        // Blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            (KPerBlock * problem.KBatch));

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_bufs,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_bufs,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);

        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            // TODO: hacky, fix it!
            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // TODO: hacky, fix it!
            // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared_0),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                make_tuple(
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
                        M1,                                      // M1 = MWave
                        M2,                                      // M2 * M3 * M4 = MPerXdl
                        M3,
                        M4)),
                    make_freeze_transform(I0),
                    make_unmerge_transform(make_tuple(
                        Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
                        N1,                                      // N1 = NWave
                        N2))),                                   // N2 = NPerXdl
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(
                    Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));

            // calculate origin of thread output tensor on global memory
            // blockwise GEMM c matrix starting index
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // shuffle: threadwise copy C from VGPR to LDS
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   CShuffleDataType,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   ck::tensor_operation::element_wise::PassThrough,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            I1,
                                                            I1,
                                                            M2,
                                                            I1,
                                                            M4,
                                                            I1>,
                                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                                   7,
                                                   1,
                                                   InMemoryDataOperationEnum::Set,
                                                   1,
                                                   true>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    ck::tensor_operation::element_wise::PassThrough{}};

            // shuffle: blockwise copy C from LDS to global
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,            // ThreadGroup
                CElementwiseOperation,      // ElementwiseOperation,
                CGlobalMemoryDataOperation, // DstInMemOp,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
                CShuffleDataType,     // typename SrcData,
                CDataType,            // typename DstData,
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>,                           // typename DimAccessOrder,
                3,                                              // index_t VectorDim,
                CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
                true,  // bool ThreadTransferSrcResetCoordinateAfterRun,
                false> // bool ThreadTransferDstResetCoordinateAfterRun>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_m_id, 0, block_n_id, 0),
                 c_element_op};

            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1,
                                           1,
                                           M2,
                                           1,
                                           M4,
                                           1>>{};

            // space filling curve for shuffled blockwise C in global mem
            constexpr auto sfc_c_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread write its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // each block copy its data from LDS to global
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
    }
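    // Illustrative sketch (not part of the original kernel): Run_2Lds differs from Run only in
    // that it keeps two independent LDS allocations, so the pipeline can prefetch tile i + 1
    // into one buffer while the MFMA stage still reads tile i from the other. The one-liner
    // below just names that alternation; the actual buffer selection is handled inside the
    // chosen BlockwiseGemmPipe via the a_block_bufs/b_block_bufs tuples.
    static constexpr int ExamplePingPongBufferIndex(int k_iter) { return k_iter & 1; }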
};

} // namespace ck

include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
...
@@ -34,8 +34,7 @@ __global__ void
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
     kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

     auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
...
@@ -48,7 +47,7 @@ __global__ void
         karg);
 #else
     ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <typename GridwiseGemm,
...
@@ -63,8 +62,7 @@ __global__ void
     // __attribute__((amdgpu_waves_per_eu(1, 1)))
     kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
 {
-#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
-    defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // Pass two lds pointer is the key to tell compiler that ds_read/write
     // operate on different lds chunk at same time without order dependecy
     __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
...
@@ -81,7 +79,7 @@ __global__ void
         karg);
 #else
     ignore = karg;
-#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+#endif // end of if (defined(__gfx9__))
 }

 template <typename ALayout,
...
@@ -605,8 +603,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
             a_lds_block_desc,
-            make_tuple(make_xor_transform(make_tuple(Number<MPerBlock / MLdsLayer>{},
-                                                     Number<AK0Number * MLdsLayer>{})),
+            make_tuple(make_xor_with_modulo_transform(make_tuple(
+                           Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
                        make_pass_through_transform(AK1Number)),
             make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
             make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
...
@@ -671,7 +669,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
             make_tuple(
                 make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                 make_pass_through_transform(Number<K0PerThreadWrite>{}),
-                make_xor_transform(
+                make_xor_with_modulo_transform(
                     make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
                 make_pass_through_transform(Number<mpair>{}),
                 make_pass_through_transform(AK1Number)),
...
@@ -742,8 +740,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
         constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
             b_lds_block_desc,
-            make_tuple(make_xor_transform(make_tuple(Number<NPerBlock / NLdsLayer>{},
-                                                     Number<BK0Number * NLdsLayer>{})),
+            make_tuple(make_xor_with_modulo_transform(make_tuple(
+                           Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
                        make_pass_through_transform(BK1Number)),
             make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
             make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
...
@@ -805,7 +803,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
             make_tuple(
                 make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                 make_pass_through_transform(Number<K0PerThreadWrite>{}),
-                make_xor_transform(
+                make_xor_with_modulo_transform(
                     make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
                 make_pass_through_transform(Number<npair>{}),
                 make_pass_through_transform(BK1Number)),
...
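The functional change in the hunks above is the switch from make_xor_transform to
make_xor_with_modulo_transform when permuting the LDS block descriptors. As a hedged,
standalone sketch of the general idea only (the exact index arithmetic lives in
include/ck/tensor_description/multi_index_transform.hpp, which this commit also touches; the
function and names below are illustrative, not the library's implementation):

    // Generic XOR swizzle: permute the columns of each row differently so that threads of a
    // wave land on different LDS banks. Assumes row_length is a power of two, so the XOR
    // result stays inside [0, row_length).
    int swizzled_column(int row, int col, int row_length)
    {
        return col ^ (row % row_length);
    }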