Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e4112de7
Commit
e4112de7
authored
May 22, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
a6ef5c39
fd72380a
Changes
39
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2813 additions
and
736 deletions
+2813
-736
CMakeLists.txt
CMakeLists.txt
+10
-6
client_example/07_grouped_convnd_fwd/CMakeLists.txt
client_example/07_grouped_convnd_fwd/CMakeLists.txt
+19
-1
client_example/07_grouped_convnd_fwd/common.hpp
client_example/07_grouped_convnd_fwd/common.hpp
+304
-0
client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp
client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp
+11
-201
client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp
client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp
+11
-169
client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp
..._example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8.cpp
+0
-0
client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp
...mple/07_grouped_convnd_fwd/grouped_conv3d_fwd_bf8_fp8.cpp
+0
-0
client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp
..._example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8.cpp
+0
-0
client_example/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp
...mple/07_grouped_convnd_fwd/grouped_conv3d_fwd_fp8_bf8.cpp
+0
-0
client_example/16_convnd_fwd/CMakeLists.txt
client_example/16_convnd_fwd/CMakeLists.txt
+0
-16
example/CMakeLists.txt
example/CMakeLists.txt
+31
-4
include/ck/host_utility/flush_cache.hpp
include/ck/host_utility/flush_cache.hpp
+14
-9
include/ck/tensor_description/multi_index_transform.hpp
include/ck/tensor_description/multi_index_transform.hpp
+11
-4
include/ck/tensor_description/multi_index_transform_helper.hpp
...de/ck/tensor_description/multi_index_transform_helper.hpp
+8
-2
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+1
-2
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
...device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
+1010
-304
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
...evice/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
+1
-2
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
.../tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
+2
-3
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp
...ion/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp
+1369
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+11
-13
No files found.
CMakeLists.txt
View file @
e4112de7
...
@@ -23,7 +23,7 @@ endif()
...
@@ -23,7 +23,7 @@ endif()
set
(
version 1.1.0
)
set
(
version 1.1.0
)
# Check support for CUDA/HIP in Cmake
# Check support for CUDA/HIP in Cmake
project
(
composable_kernel VERSION
${
version
}
LANGUAGES CXX
)
project
(
composable_kernel VERSION
${
version
}
LANGUAGES CXX
HIP
)
include
(
CTest
)
include
(
CTest
)
find_package
(
Python3 3.6 COMPONENTS Interpreter REQUIRED
)
find_package
(
Python3 3.6 COMPONENTS Interpreter REQUIRED
)
...
@@ -112,7 +112,7 @@ message("checking which targets are supported")
...
@@ -112,7 +112,7 @@ message("checking which targets are supported")
#Setting GPU_TARGETS on command line will override this list
#Setting GPU_TARGETS on command line will override this list
if
(
NOT PROFILER_ONLY
)
if
(
NOT PROFILER_ONLY
)
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS
rocm_check_target_ids
(
DEFAULT_GPU_TARGETS
TARGETS
"gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102"
)
TARGETS
"
gfx900;gfx906;
gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100;gfx1101;gfx1102"
)
else
()
else
()
add_definitions
(
-DPROFILER_ONLY
)
add_definitions
(
-DPROFILER_ONLY
)
set
(
GPU_TARGETS
""
CACHE STRING
""
FORCE
)
set
(
GPU_TARGETS
""
CACHE STRING
""
FORCE
)
...
@@ -135,12 +135,10 @@ endif()
...
@@ -135,12 +135,10 @@ endif()
message
(
"Supported GPU_TARGETS=
${
DEFAULT_GPU_TARGETS
}
"
)
message
(
"Supported GPU_TARGETS=
${
DEFAULT_GPU_TARGETS
}
"
)
set
(
AMDGPU_TARGETS
"
${
DEFAULT_GPU_TARGETS
}
"
CACHE STRING
" "
FORCE
)
if
(
GPU_TARGETS
)
if
(
GPU_TARGETS
)
message
(
"Building CK for the following targets:
${
GPU_TARGETS
}
"
)
message
(
"Building CK for the following targets:
${
GPU_TARGETS
}
"
)
else
()
else
()
message
(
"Building CK for the
following
targets:
${
AMD
GPU_TARGETS
}
"
)
message
(
"Building CK for the
default
targets:
${
DEFAULT_
GPU_TARGETS
}
"
)
endif
()
endif
()
if
(
GPU_TARGETS
)
if
(
GPU_TARGETS
)
...
@@ -225,7 +223,13 @@ link_libraries(Threads::Threads)
...
@@ -225,7 +223,13 @@ link_libraries(Threads::Threads)
set
(
CMAKE_CXX_STANDARD 17
)
set
(
CMAKE_CXX_STANDARD 17
)
set
(
CMAKE_CXX_STANDARD_REQUIRED ON
)
set
(
CMAKE_CXX_STANDARD_REQUIRED ON
)
set
(
CMAKE_CXX_EXTENSIONS OFF
)
set
(
CMAKE_CXX_EXTENSIONS OFF
)
message
(
"CMAKE_CXX_COMPILER_ID:
${
CMAKE_CXX_COMPILER_ID
}
"
)
message
(
"CMAKE_CXX_COMPILER:
${
CMAKE_CXX_COMPILER
}
"
)
## HIP
set
(
CMAKE_HIP_PLATFORM amd
)
set
(
CMAKE_HIP_COMPILER
${
CMAKE_CXX_COMPILER
}
)
set
(
CMAKE_HIP_EXTENSIONS ON
)
message
(
"CMAKE_HIP_COMPILER:
${
CMAKE_HIP_COMPILER
}
"
)
## OpenMP
## OpenMP
if
(
CMAKE_CXX_COMPILER_ID MATCHES
"Clang"
)
if
(
CMAKE_CXX_COMPILER_ID MATCHES
"Clang"
)
...
...
client_example/07_grouped_convnd_fwd/CMakeLists.txt
View file @
e4112de7
...
@@ -4,4 +4,22 @@ if(GPU_TARGETS MATCHES "gfx9")
...
@@ -4,4 +4,22 @@ if(GPU_TARGETS MATCHES "gfx9")
add_executable
(
client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp
)
add_executable
(
client_grouped_conv1d_fwd grouped_conv1d_fwd.cpp
)
target_link_libraries
(
client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations
)
target_link_libraries
(
client_grouped_conv1d_fwd PRIVATE composable_kernel::device_conv_operations
)
if
((
DTYPES MATCHES
"fp8"
)
OR NOT DEFINED DTYPES
)
add_executable
(
client_grouped_conv3d_fwd_fp8 grouped_conv3d_fwd_fp8.cpp
)
target_link_libraries
(
client_grouped_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations
)
endif
()
if
((
DTYPES MATCHES
"bf8"
)
OR NOT DEFINED DTYPES
)
add_executable
(
client_grouped_conv3d_fwd_bf8 grouped_conv3d_fwd_bf8.cpp
)
target_link_libraries
(
client_grouped_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations
)
endif
()
if
((
DTYPES MATCHES
"fp8"
AND DTYPES MATCHES
"bf8"
)
OR NOT DEFINED DTYPES
)
add_executable
(
client_grouped_conv3d_fwd_fp8_bf8 grouped_conv3d_fwd_fp8_bf8.cpp
)
target_link_libraries
(
client_grouped_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations
)
add_executable
(
client_grouped_conv3d_fwd_bf8_fp8 grouped_conv3d_fwd_bf8_fp8.cpp
)
target_link_libraries
(
client_grouped_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations
)
endif
()
endif
()
endif
()
client_example/07_grouped_convnd_fwd/common.hpp
0 → 100644
View file @
e4112de7
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <string>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
template
<
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
=
3
>
std
::
size_t
GetFlops
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
output_lengths
,
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
weights_lengths
)
{
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
ck
::
index_t
G
=
weights_lengths
[
0
];
ck
::
index_t
N
=
output_lengths
[
1
];
ck
::
index_t
K
=
weights_lengths
[
1
];
ck
::
index_t
C
=
weights_lengths
[
2
];
return
static_cast
<
std
::
size_t
>
(
2
)
*
G
*
N
*
K
*
C
*
std
::
accumulate
(
std
::
next
(
std
::
begin
(
output_lengths
),
NumNonSpatialDim
),
std
::
end
(
output_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
())
*
std
::
accumulate
(
std
::
next
(
std
::
begin
(
weights_lengths
),
NumNonSpatialDim
),
std
::
end
(
weights_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
());
}
template
<
typename
InDataType
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
=
3
>
std
::
size_t
GetInputByte
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
input_lengths
)
{
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
return
sizeof
(
InDataType
)
*
std
::
accumulate
(
std
::
begin
(
input_lengths
),
std
::
end
(
input_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
());
}
template
<
typename
WeiDataType
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
=
3
>
std
::
size_t
GetWeightByte
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
weights_lengths
)
{
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
return
sizeof
(
WeiDataType
)
*
std
::
accumulate
(
std
::
begin
(
weights_lengths
),
std
::
end
(
weights_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<>
());
}
template
<
typename
OutDataType
,
ck
::
index_t
NumDimSpatial
,
ck
::
index_t
NumNonSpatialDim
=
3
>
std
::
size_t
GetOutputByte
(
const
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>&
output_lengths
)
{
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
return
sizeof
(
OutDataType
)
*
std
::
accumulate
(
std
::
begin
(
output_lengths
),
std
::
end
(
output_lengths
),
static_cast
<
std
::
size_t
>
(
1
),
std
::
multiplies
<
std
::
size_t
>
());
}
template
<
ck
::
index_t
NumDimSpatial
,
typename
InDataType
,
typename
WeiDataType
,
typename
OutDataType
,
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
,
ck
::
index_t
NumNonSpatialDim
=
3
,
typename
AComputeType
=
InDataType
,
typename
BComputeType
=
AComputeType
>
bool
run_grouped_conv_fwd
(
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
in_lengths
,
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
wei_lengths
,
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
out_lengths
)
{
std
::
size_t
in_mem_size
=
GetInputByte
<
InDataType
,
NumDimSpatial
>
(
in_lengths
);
std
::
size_t
wei_mem_size
=
GetWeightByte
<
WeiDataType
,
NumDimSpatial
>
(
wei_lengths
);
std
::
size_t
out_mem_size
=
GetOutputByte
<
OutDataType
,
NumDimSpatial
>
(
out_lengths
);
SimpleDeviceMem
in
(
in_mem_size
);
SimpleDeviceMem
wei
(
wei_mem_size
);
SimpleDeviceMem
out
(
out_mem_size
);
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
in_strides
;
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
wei_strides
;
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
out_strides
;
in_strides
.
fill
(
0
);
wei_strides
.
fill
(
0
);
out_strides
.
fill
(
0
);
in_strides
.
back
()
=
1
;
wei_strides
.
back
()
=
1
;
out_strides
.
back
()
=
1
;
std
::
partial_sum
(
rbegin
(
in_lengths
),
std
::
prev
(
rend
(
in_lengths
)),
std
::
next
(
rbegin
(
in_strides
)),
std
::
multiplies
<>
{});
std
::
partial_sum
(
rbegin
(
wei_lengths
),
std
::
prev
(
rend
(
wei_lengths
)),
std
::
next
(
rbegin
(
wei_strides
)),
std
::
multiplies
<>
{});
std
::
partial_sum
(
rbegin
(
out_lengths
),
std
::
prev
(
rend
(
out_lengths
)),
std
::
next
(
rbegin
(
out_strides
)),
std
::
multiplies
<>
{});
// transpose NDHWGC/KZYXGC/NDHWGK to GNDHWC/GKZYXC/GNDHWK to GNCDHW/GKCZYX/GNKDHW
std
::
rotate
(
std
::
next
(
rbegin
(
in_lengths
)),
std
::
next
(
rbegin
(
in_lengths
),
2
),
rend
(
in_lengths
));
std
::
rotate
(
rbegin
(
in_lengths
),
std
::
next
(
rbegin
(
in_lengths
)),
std
::
next
(
rbegin
(
in_lengths
),
NumDimSpatial
+
1
));
std
::
rotate
(
std
::
next
(
rbegin
(
in_strides
)),
std
::
next
(
rbegin
(
in_strides
),
2
),
rend
(
in_strides
));
std
::
rotate
(
rbegin
(
in_strides
),
std
::
next
(
rbegin
(
in_strides
)),
std
::
next
(
rbegin
(
in_strides
),
NumDimSpatial
+
1
));
std
::
rotate
(
rbegin
(
wei_lengths
),
std
::
next
(
rbegin
(
wei_lengths
)),
std
::
next
(
rbegin
(
wei_lengths
),
NumDimSpatial
+
1
));
std
::
rotate
(
rbegin
(
wei_strides
),
std
::
next
(
rbegin
(
wei_strides
)),
std
::
next
(
rbegin
(
wei_strides
),
NumDimSpatial
+
1
));
std
::
rotate
(
std
::
next
(
rbegin
(
out_lengths
)),
std
::
next
(
rbegin
(
out_lengths
),
2
),
rend
(
out_lengths
));
std
::
rotate
(
rbegin
(
out_lengths
),
std
::
next
(
rbegin
(
out_lengths
)),
std
::
next
(
rbegin
(
out_lengths
),
NumDimSpatial
+
1
));
std
::
rotate
(
std
::
next
(
rbegin
(
out_strides
)),
std
::
next
(
rbegin
(
out_strides
),
2
),
rend
(
out_strides
));
std
::
rotate
(
rbegin
(
out_strides
),
std
::
next
(
rbegin
(
out_strides
)),
std
::
next
(
rbegin
(
out_strides
),
NumDimSpatial
+
1
));
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_strides
;
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
conv_filter_dilations
;
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
;
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_right_pads
;
conv_filter_strides
.
fill
(
1
);
conv_filter_dilations
.
fill
(
1
);
input_left_pads
.
fill
(
1
);
input_right_pads
.
fill
(
1
);
std
::
size_t
flop
=
GetFlops
<
NumDimSpatial
>
(
out_lengths
,
wei_lengths
);
std
::
size_t
num_bytes
=
in_mem_size
+
wei_mem_size
+
out_mem_size
;
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
ck
::
Tuple
<>
,
OutLayout
,
InDataType
,
WeiDataType
,
ck
::
Tuple
<>
,
OutDataType
,
PassThrough
,
PassThrough
,
PassThrough
,
AComputeType
,
BComputeType
>
;
// get device op instances
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
int
best_op_id
=
-
1
;
float
best_avg_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
float
best_tflops
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
0
>
{},
out
.
GetDeviceBuffer
(),
in_lengths
,
in_strides
,
wei_lengths
,
wei_strides
,
std
::
array
<
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
,
0
>
{{}},
std
::
array
<
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
,
0
>
{{}},
out_lengths
,
out_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
num_bytes
/
1.E6
/
avg_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
avg_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_op_id
=
i
;
best_op_name
=
op_name
;
best_avg_time
=
avg_time
;
best_gb_per_sec
=
gb_per_sec
;
best_tflops
=
tflops
;
}
}
else
{
std
::
cerr
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
if
(
best_op_id
<
0
)
{
std
::
cerr
<<
"no suitable instance"
<<
std
::
endl
;
return
false
;
}
std
::
cout
<<
"Best Perf: "
<<
std
::
setw
(
10
)
<<
best_avg_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best intance
{
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
std
::
array
<
const
void
*
,
0
>
{},
out
.
GetDeviceBuffer
(),
in_lengths
,
in_strides
,
wei_lengths
,
wei_strides
,
std
::
array
<
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
,
0
>
{{}},
std
::
array
<
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
NumNonSpatialDim
>
,
0
>
{{}},
out_lengths
,
out_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
return
true
;
}
client_example/07_grouped_convnd_fwd/grouped_conv1d_fwd.cpp
View file @
e4112de7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "common.hpp"
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using
InDataType
=
ck
::
half_t
;
using
InDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
half_t
;
...
@@ -31,199 +24,16 @@ static constexpr ck::index_t X = 3;
...
@@ -31,199 +24,16 @@ static constexpr ck::index_t X = 3;
static
constexpr
ck
::
index_t
Wi
=
28
;
static
constexpr
ck
::
index_t
Wi
=
28
;
static
constexpr
ck
::
index_t
Wo
=
28
;
static
constexpr
ck
::
index_t
Wo
=
28
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
()
int
main
()
{
{
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
in_lengths
{
G
,
N
,
Wi
,
C
};
return
run_grouped_conv_fwd
<
NumDimSpatial
,
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
in_strides
{
0
,
0
,
0
,
1
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
wei_lengths
{
G
,
K
,
X
,
C
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
wei_strides
{
0
,
0
,
0
,
1
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
out_lengths
{
G
,
N
,
Wo
,
K
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
+
3
>
out_strides
{
0
,
0
,
0
,
1
};
std
::
partial_sum
(
rbegin
(
in_lengths
),
std
::
prev
(
rend
(
in_lengths
)),
std
::
next
(
rbegin
(
in_strides
)),
std
::
multiplies
<>
{});
std
::
partial_sum
(
rbegin
(
wei_lengths
),
std
::
prev
(
rend
(
wei_lengths
)),
std
::
next
(
rbegin
(
wei_strides
)),
std
::
multiplies
<>
{});
std
::
partial_sum
(
rbegin
(
out_lengths
),
std
::
prev
(
rend
(
out_lengths
)),
std
::
next
(
rbegin
(
out_strides
)),
std
::
multiplies
<>
{});
// transpose GNWC/GKXC/GNWK to GNCW/GKCX/GNCW
std
::
rotate
(
rbegin
(
in_lengths
),
std
::
next
(
rbegin
(
in_lengths
)),
std
::
next
(
rbegin
(
in_lengths
),
NumDimSpatial
+
1
));
std
::
rotate
(
rbegin
(
in_strides
),
std
::
next
(
rbegin
(
in_strides
)),
std
::
next
(
rbegin
(
in_strides
),
NumDimSpatial
+
1
));
std
::
rotate
(
rbegin
(
wei_lengths
),
std
::
next
(
rbegin
(
wei_lengths
)),
std
::
next
(
rbegin
(
wei_lengths
),
NumDimSpatial
+
1
));
std
::
rotate
(
rbegin
(
wei_strides
),
std
::
next
(
rbegin
(
wei_strides
)),
std
::
next
(
rbegin
(
wei_strides
),
NumDimSpatial
+
1
));
std
::
rotate
(
rbegin
(
out_lengths
),
std
::
next
(
rbegin
(
out_lengths
)),
std
::
next
(
rbegin
(
out_lengths
),
NumDimSpatial
+
1
));
std
::
rotate
(
rbegin
(
out_strides
),
std
::
next
(
rbegin
(
out_strides
)),
std
::
next
(
rbegin
(
out_strides
),
NumDimSpatial
+
1
));
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_strides
{
1
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_dilations
{
1
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_right_pads
{
1
};
SimpleDeviceMem
in
(
sizeof
(
InDataType
)
*
G
*
N
*
Wi
*
C
);
SimpleDeviceMem
wei
(
sizeof
(
WeiDataType
)
*
G
*
K
*
X
*
C
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
G
*
N
*
Wo
*
K
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
ck
::
Tuple
<>
,
OutLayout
,
InDataType
,
InDataType
,
WeiDataType
,
WeiDataType
,
ck
::
Tuple
<>
,
OutDataType
,
OutDataType
,
PassThrough
,
InLayout
,
PassThrough
,
WeiLayout
,
PassThrough
>
;
OutLayout
,
3
>
({
N
,
Wi
,
G
,
C
},
{
G
,
K
,
X
,
C
},
{
N
,
Wo
,
G
,
K
})
// get device op instances
?
EXIT_SUCCESS
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
:
EXIT_FAILURE
;
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
int
best_op_id
=
-
1
;
float
best_avg_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
float
best_tflops
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
{},
out
.
GetDeviceBuffer
(),
in_lengths
,
in_strides
,
wei_lengths
,
wei_strides
,
{},
{},
out_lengths
,
out_strides
,
filter_strides
,
filter_dilations
,
input_left_pads
,
input_right_pads
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
G
*
N
*
K
*
C
*
Wo
*
X
;
std
::
size_t
num_bytes
=
sizeof
(
InDataType
)
*
G
*
N
*
Wi
*
C
+
sizeof
(
WeiDataType
)
*
G
*
K
*
X
*
C
+
sizeof
(
OutDataType
)
*
G
*
N
*
Wo
*
K
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
num_bytes
/
1.E6
/
avg_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
avg_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_op_id
=
i
;
best_op_name
=
op_name
;
best_avg_time
=
avg_time
;
best_gb_per_sec
=
gb_per_sec
;
best_tflops
=
tflops
;
}
}
else
{
std
::
cerr
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
if
(
best_op_id
<
0
)
{
std
::
cerr
<<
"no suitable instance"
<<
std
::
endl
;
return
EXIT_FAILURE
;
}
std
::
cout
<<
"Best Perf: "
<<
std
::
setw
(
10
)
<<
best_avg_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best intance
{
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
{},
out
.
GetDeviceBuffer
(),
in_lengths
,
in_strides
,
wei_lengths
,
wei_strides
,
{},
{},
out_lengths
,
out_strides
,
filter_strides
,
filter_dilations
,
input_left_pads
,
input_right_pads
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
}
}
client_example/07_grouped_convnd_fwd/grouped_conv2d_fwd.cpp
View file @
e4112de7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "common.hpp"
#include <iomanip>
#include <iostream>
#include <iterator>
#include <numeric>
#include <vector>
#include "ck/ck.hpp"
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/grouped_convolution_forward.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
using
InDataType
=
ck
::
half_t
;
using
InDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
half_t
;
using
WeiDataType
=
ck
::
half_t
;
...
@@ -34,167 +27,16 @@ static constexpr ck::index_t Wi = 28; // input W
...
@@ -34,167 +27,16 @@ static constexpr ck::index_t Wi = 28; // input W
static
constexpr
ck
::
index_t
Ho
=
28
;
// output H
static
constexpr
ck
::
index_t
Ho
=
28
;
// output H
static
constexpr
ck
::
index_t
Wo
=
28
;
// output W
static
constexpr
ck
::
index_t
Wo
=
28
;
// output W
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
()
int
main
()
{
{
// We have NHWGC/GKYXC/NHWGK (x, weight, y) in memory space
return
run_grouped_conv_fwd
<
NumDimSpatial
,
// However, CK's API only accept length and stride with order of GNCHW/GKCYX/GNCHW
// Hence, we need to adjust the order of stride
std
::
array
<
ck
::
index_t
,
5
>
in_lengths
{
G
,
N
,
C
,
Hi
,
Wi
};
std
::
array
<
ck
::
index_t
,
5
>
in_strides
{
C
,
Hi
*
Wi
*
G
*
C
,
1
,
Wi
*
G
*
C
,
G
*
C
};
std
::
array
<
ck
::
index_t
,
5
>
wei_lengths
{
G
,
K
,
C
,
Y
,
X
};
std
::
array
<
ck
::
index_t
,
5
>
wei_strides
{
K
*
Y
*
X
*
C
,
Y
*
X
*
C
,
1
,
X
*
C
,
C
};
std
::
array
<
ck
::
index_t
,
5
>
out_lengths
{
G
,
N
,
K
,
Ho
,
Wo
};
std
::
array
<
ck
::
index_t
,
5
>
out_strides
{
C
,
Ho
*
Wo
*
G
*
C
,
1
,
Wo
*
G
*
C
,
G
*
C
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_strides
{
1
,
1
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
filter_dilations
{
1
,
1
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_left_pads
{
1
,
1
};
std
::
array
<
ck
::
index_t
,
NumDimSpatial
>
input_right_pads
{
1
,
1
};
SimpleDeviceMem
in
(
sizeof
(
InDataType
)
*
N
*
Hi
*
Wi
*
G
*
C
);
SimpleDeviceMem
wei
(
sizeof
(
WeiDataType
)
*
G
*
K
*
Y
*
X
*
C
);
SimpleDeviceMem
out
(
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceGroupedConvFwdMultipleABD
<
NumDimSpatial
,
InLayout
,
WeiLayout
,
ck
::
Tuple
<>
,
OutLayout
,
InDataType
,
InDataType
,
WeiDataType
,
WeiDataType
,
ck
::
Tuple
<>
,
OutDataType
,
OutDataType
,
PassThrough
,
InLayout
,
PassThrough
,
WeiLayout
,
PassThrough
>
;
OutLayout
,
3
>
({
N
,
Hi
,
Wi
,
G
,
C
},
{
G
,
K
,
Y
,
X
,
C
},
{
N
,
Ho
,
Wo
,
G
,
K
})
// get device op instances
?
EXIT_SUCCESS
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
:
EXIT_FAILURE
;
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
int
best_op_id
=
-
1
;
float
best_avg_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
float
best_tflops
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
{},
out
.
GetDeviceBuffer
(),
in_lengths
,
in_strides
,
wei_lengths
,
wei_strides
,
{},
{},
out_lengths
,
out_strides
,
filter_strides
,
filter_dilations
,
input_left_pads
,
input_right_pads
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
flop
=
std
::
size_t
(
2
)
*
G
*
N
*
K
*
C
*
Ho
*
Wo
*
Y
*
X
;
std
::
size_t
num_bytes
=
sizeof
(
InDataType
)
*
N
*
Hi
*
Wi
*
G
*
C
+
sizeof
(
WeiDataType
)
*
G
*
K
*
Y
*
X
*
C
+
sizeof
(
OutDataType
)
*
N
*
Ho
*
Wo
*
G
*
K
;
float
tflops
=
static_cast
<
float
>
(
flop
)
/
1.E9
/
avg_time
;
float
gb_per_sec
=
num_bytes
/
1.E6
/
avg_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
avg_time
<<
" ms, "
<<
tflops
<<
" TFlops, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
tflops
>
best_tflops
)
{
best_op_id
=
i
;
best_op_name
=
op_name
;
best_avg_time
=
avg_time
;
best_gb_per_sec
=
gb_per_sec
;
best_tflops
=
tflops
;
}
}
else
{
std
::
cerr
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
if
(
best_op_id
<
0
)
{
std
::
cerr
<<
"no suitable instance"
<<
std
::
endl
;
return
EXIT_FAILURE
;
}
std
::
cout
<<
"Best Perf: "
<<
std
::
setw
(
10
)
<<
best_avg_time
<<
" ms, "
<<
best_tflops
<<
" TFlops, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best intance
{
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
in
.
GetDeviceBuffer
(),
wei
.
GetDeviceBuffer
(),
{},
out
.
GetDeviceBuffer
(),
in_lengths
,
in_strides
,
wei_lengths
,
wei_strides
,
{},
{},
out_lengths
,
out_strides
,
filter_strides
,
filter_dilations
,
input_left_pads
,
input_right_pads
,
PassThrough
{},
PassThrough
{},
PassThrough
{});
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
}
}
client_example/
16
_convnd_fwd/conv3d_fwd_bf8.cpp
→
client_example/
07_grouped
_convnd_fwd/
grouped_
conv3d_fwd_bf8.cpp
View file @
e4112de7
File moved
client_example/
16
_convnd_fwd/conv3d_fwd_bf8_fp8.cpp
→
client_example/
07_grouped
_convnd_fwd/
grouped_
conv3d_fwd_bf8_fp8.cpp
View file @
e4112de7
File moved
client_example/
16
_convnd_fwd/conv3d_fwd_fp8.cpp
→
client_example/
07_grouped
_convnd_fwd/
grouped_
conv3d_fwd_fp8.cpp
View file @
e4112de7
File moved
client_example/
16
_convnd_fwd/conv3d_fwd_fp8_bf8.cpp
→
client_example/
07_grouped
_convnd_fwd/
grouped_
conv3d_fwd_fp8_bf8.cpp
View file @
e4112de7
File moved
client_example/16_convnd_fwd/CMakeLists.txt
View file @
e4112de7
...
@@ -7,22 +7,6 @@ endif()
...
@@ -7,22 +7,6 @@ endif()
if
((
DTYPES MATCHES
"fp8"
)
OR NOT DEFINED DTYPES
)
if
((
DTYPES MATCHES
"fp8"
)
OR NOT DEFINED DTYPES
)
add_executable
(
client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp
)
add_executable
(
client_conv3d_fwd_fp16_comp_fp8 conv3d_fwd_fp16_comp_fp8.cpp
)
target_link_libraries
(
client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations
)
target_link_libraries
(
client_conv3d_fwd_fp16_comp_fp8 PRIVATE composable_kernel::device_conv_operations
)
add_executable
(
client_conv3d_fwd_fp8 conv3d_fwd_fp8.cpp
)
target_link_libraries
(
client_conv3d_fwd_fp8 PRIVATE composable_kernel::device_conv_operations
)
endif
()
if
((
DTYPES MATCHES
"bf8"
)
OR NOT DEFINED DTYPES
)
add_executable
(
client_conv3d_fwd_bf8 conv3d_fwd_bf8.cpp
)
target_link_libraries
(
client_conv3d_fwd_bf8 PRIVATE composable_kernel::device_conv_operations
)
endif
()
if
((
DTYPES MATCHES
"fp8"
AND DTYPES MATCHES
"bf8"
)
OR NOT DEFINED DTYPES
)
add_executable
(
client_conv3d_fwd_fp8_bf8 conv3d_fwd_fp8_bf8.cpp
)
target_link_libraries
(
client_conv3d_fwd_fp8_bf8 PRIVATE composable_kernel::device_conv_operations
)
add_executable
(
client_conv3d_fwd_bf8_fp8 conv3d_fwd_bf8_fp8.cpp
)
target_link_libraries
(
client_conv3d_fwd_bf8_fp8 PRIVATE composable_kernel::device_conv_operations
)
endif
()
endif
()
if
((
DTYPES MATCHES
"fp32"
)
OR NOT DEFINED DTYPES
)
if
((
DTYPES MATCHES
"fp32"
)
OR NOT DEFINED DTYPES
)
...
...
example/CMakeLists.txt
View file @
e4112de7
...
@@ -44,6 +44,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
...
@@ -44,6 +44,13 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endif
()
endif
()
endforeach
()
endforeach
()
endif
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
EX_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
EX_TARGETS
${
GPU_TARGETS
}
)
endif
()
#Do not build any DL examples if DL_KERNELS not set
#Do not build any DL examples if DL_KERNELS not set
foreach
(
source IN LISTS FILE_NAME
)
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
...
@@ -53,23 +60,30 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
...
@@ -53,23 +60,30 @@ function(add_example_executable EXAMPLE_NAME FILE_NAME)
endforeach
()
endforeach
()
#Do not build any XDL examples if gfx9 targets are not on the list
#Do not build any XDL examples if gfx9 targets are not on the list
foreach
(
source IN LISTS FILE_NAME
)
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT
GPU
_TARGETS MATCHES
"gfx9"
AND source MATCHES
"_xdl"
)
if
(
NOT
EX
_TARGETS MATCHES
"gfx9"
AND source MATCHES
"_xdl"
)
message
(
"removing xdl example
${
source
}
"
)
message
(
"removing xdl example
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
#Do not build any WMMA examples if gfx11 targets are not on the list
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach
(
source IN LISTS FILE_NAME
)
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT
GPU
_TARGETS MATCHES
"gfx11"
AND source MATCHES
"_wmma"
)
if
(
NOT
EX
_TARGETS MATCHES
"gfx11"
AND source MATCHES
"_wmma"
)
message
(
"removing wmma example
${
source
}
"
)
message
(
"removing wmma example
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
#only continue if there are some source files left on the list
#only continue if there are some source files left on the list
if
(
FILE_NAME
)
if
(
FILE_NAME
)
if
(
FILE_NAME MATCHES
"_xdl"
)
list
(
REMOVE_ITEM EX_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
elseif
(
FILE_NAME MATCHES
"_wmma"
)
list
(
REMOVE_ITEM EX_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
endif
()
set_source_files_properties
(
${
FILE_NAME
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
EXAMPLE_NAME
}
${
FILE_NAME
}
)
add_executable
(
${
EXAMPLE_NAME
}
${
FILE_NAME
}
)
target_link_libraries
(
${
EXAMPLE_NAME
}
PRIVATE utility
)
target_link_libraries
(
${
EXAMPLE_NAME
}
PRIVATE utility
)
add_test
(
NAME
${
EXAMPLE_NAME
}
COMMAND $<TARGET_FILE:
${
EXAMPLE_NAME
}
>
${
ARGN
}
)
add_test
(
NAME
${
EXAMPLE_NAME
}
COMMAND $<TARGET_FILE:
${
EXAMPLE_NAME
}
>
${
ARGN
}
)
set_property
(
TARGET
${
EXAMPLE_NAME
}
PROPERTY HIP_ARCHITECTURES
${
EX_TARGETS
}
)
add_dependencies
(
examples
${
EXAMPLE_NAME
}
)
add_dependencies
(
examples
${
EXAMPLE_NAME
}
)
add_dependencies
(
check
${
EXAMPLE_NAME
}
)
add_dependencies
(
check
${
EXAMPLE_NAME
}
)
rocm_install
(
TARGETS
${
EXAMPLE_NAME
}
COMPONENT examples
)
rocm_install
(
TARGETS
${
EXAMPLE_NAME
}
COMPONENT examples
)
...
@@ -118,6 +132,12 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
...
@@ -118,6 +132,12 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endif
()
endif
()
endforeach
()
endforeach
()
endif
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
EX_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
EX_TARGETS
${
GPU_TARGETS
}
)
endif
()
#Do not build any DL examples if DL_KERNELS not set
#Do not build any DL examples if DL_KERNELS not set
foreach
(
source IN LISTS FILE_NAME
)
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
...
@@ -127,23 +147,30 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
...
@@ -127,23 +147,30 @@ function(add_example_executable_no_testing EXAMPLE_NAME FILE_NAME)
endforeach
()
endforeach
()
#Do not build any XDL examples if gfx9 targets are not on the list
#Do not build any XDL examples if gfx9 targets are not on the list
foreach
(
source IN LISTS FILE_NAME
)
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT
GPU
_TARGETS MATCHES
"gfx9"
AND source MATCHES
"_xdl"
)
if
(
NOT
EX
_TARGETS MATCHES
"gfx9"
AND source MATCHES
"_xdl"
)
message
(
"removing xdl example
${
source
}
"
)
message
(
"removing xdl example
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
#Do not build any WMMA examples if gfx11 targets are not on the list
#Do not build any WMMA examples if gfx11 targets are not on the list
foreach
(
source IN LISTS FILE_NAME
)
foreach
(
source IN LISTS FILE_NAME
)
if
(
NOT
GPU
_TARGETS MATCHES
"gfx11"
AND source MATCHES
"_wmma"
)
if
(
NOT
EX
_TARGETS MATCHES
"gfx11"
AND source MATCHES
"_wmma"
)
message
(
"removing wmma example
${
source
}
"
)
message
(
"removing wmma example
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
list
(
REMOVE_ITEM FILE_NAME
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
#only continue if there are some source files left on the list
#only continue if there are some source files left on the list
if
(
FILE_NAME
)
if
(
FILE_NAME
)
if
(
FILE_NAME MATCHES
"_xdl"
)
list
(
REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103
)
elseif
(
FILE_NAME MATCHES
"_wmma"
)
list
(
REMOVE_ITEM EX_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
endif
()
set_source_files_properties
(
${
FILE_NAME
}
PROPERTIES LANGUAGE HIP
)
add_executable
(
${
EXAMPLE_NAME
}
${
FILE_NAME
}
)
add_executable
(
${
EXAMPLE_NAME
}
${
FILE_NAME
}
)
target_link_libraries
(
${
EXAMPLE_NAME
}
PRIVATE utility
)
target_link_libraries
(
${
EXAMPLE_NAME
}
PRIVATE utility
)
add_dependencies
(
examples
${
EXAMPLE_NAME
}
)
add_dependencies
(
examples
${
EXAMPLE_NAME
}
)
set_property
(
TARGET
${
EXAMPLE_NAME
}
PROPERTY HIP_ARCHITECTURES
${
EX_TARGETS
}
)
rocm_install
(
TARGETS
${
EXAMPLE_NAME
}
COMPONENT examples
)
rocm_install
(
TARGETS
${
EXAMPLE_NAME
}
COMPONENT examples
)
set
(
result 0
)
set
(
result 0
)
endif
()
endif
()
...
...
include/ck/host_utility/flush_cache.hpp
View file @
e4112de7
...
@@ -104,14 +104,19 @@ inline void flush_icache()
...
@@ -104,14 +104,19 @@ inline void flush_icache()
hip_check_error
(
hipGetLastError
());
hip_check_error
(
hipGetLastError
());
}
}
// if TimePrePress == false, return time does not include preprocess's time
// if TimePrePress == false, return time does not include preprocess's time
template
<
bool
TimePreprocess
,
typename
Args
,
typename
F
,
typename
PreProcessFunc
>
template
<
bool
TimePreprocess
,
typename
GemmArgs
,
typename
...
Args
,
typename
F
,
typename
PreProcessFunc
>
float
launch_and_time_kernel_with_preprocess
(
const
StreamConfig
&
stream_config
,
float
launch_and_time_kernel_with_preprocess
(
const
StreamConfig
&
stream_config
,
PreProcessFunc
preprocess
,
PreProcessFunc
preprocess
,
F
kernel
,
F
kernel
,
dim3
grid_dim
,
dim3
grid_dim
,
dim3
block_dim
,
dim3
block_dim
,
std
::
size_t
lds_byte
,
std
::
size_t
lds_byte
,
Args
&
args
)
GemmArgs
&
gemm_args
,
Args
...
args
)
{
{
#if CK_TIME_KERNEL
#if CK_TIME_KERNEL
#define MEDIAN 1
#define MEDIAN 1
...
@@ -133,7 +138,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -133,7 +138,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
// warm up
// warm up
for
(
int
i
=
0
;
i
<
stream_config
.
cold_niters_
;
++
i
)
for
(
int
i
=
0
;
i
<
stream_config
.
cold_niters_
;
++
i
)
{
{
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
);
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
gemm_args
,
args
...
);
hip_check_error
(
hipGetLastError
());
hip_check_error
(
hipGetLastError
());
}
}
...
@@ -172,7 +177,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -172,7 +177,7 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
preprocess
();
preprocess
();
}
}
// run real kernel
// run real kernel
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
);
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
gemm_args
,
args
...
);
hip_check_error
(
hipGetLastError
());
hip_check_error
(
hipGetLastError
());
// end real kernel
// end real kernel
...
@@ -190,9 +195,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -190,9 +195,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
{
{
std
::
cout
<<
"i: "
<<
i
<<
" cur_time: "
<<
cur_time
<<
std
::
endl
;
std
::
cout
<<
"i: "
<<
i
<<
" cur_time: "
<<
cur_time
<<
std
::
endl
;
printf
(
"args.p_a_grid: %p, args.p_b_grid:%p
\n
"
,
printf
(
"
gemm_
args.p_a_grid: %p,
gemm_
args.p_b_grid:%p
\n
"
,
static_cast
<
const
void
*>
(
args
.
p_a_grid
),
static_cast
<
const
void
*>
(
gemm_
args
.
p_a_grid
),
static_cast
<
const
void
*>
(
args
.
p_b_grid
));
static_cast
<
const
void
*>
(
gemm_
args
.
p_b_grid
));
}
}
}
}
...
@@ -216,13 +221,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
...
@@ -216,13 +221,13 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
else
else
{
{
preprocess
();
preprocess
();
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
);
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
gemm_args
,
args
...
);
hip_check_error
(
hipGetLastError
());
hip_check_error
(
hipGetLastError
());
return
0
;
return
0
;
}
}
#else
#else
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
);
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
gemm_args
,
args
...
);
hip_check_error
(
hipGetLastError
());
hip_check_error
(
hipGetLastError
());
return
0
;
return
0
;
...
...
include/ck/tensor_description/multi_index_transform.hpp
View file @
e4112de7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -1952,7 +1952,7 @@ struct Modulo
...
@@ -1952,7 +1952,7 @@ struct Modulo
}
}
};
};
template
<
typename
LowLengths
>
template
<
typename
LowLengths
,
bool
ApplyModulo
>
struct
Xor
struct
Xor
{
{
using
LowerIndex
=
MultiIndex
<
2
>
;
using
LowerIndex
=
MultiIndex
<
2
>
;
...
@@ -1981,9 +1981,16 @@ struct Xor
...
@@ -1981,9 +1981,16 @@ struct Xor
idx_low
(
Number
<
0
>
{})
=
idx_up
[
Number
<
0
>
{}];
idx_low
(
Number
<
0
>
{})
=
idx_up
[
Number
<
0
>
{}];
if
constexpr
(
ApplyModulo
)
{
idx_low
(
Number
<
1
>
{})
=
idx_low
(
Number
<
1
>
{})
=
idx_up
[
Number
<
1
>
{}]
^
(
idx_up
[
Number
<
0
>
{}]
%
up_lengths_
[
Number
<
1
>
{}]);
idx_up
[
Number
<
1
>
{}]
^
(
idx_up
[
Number
<
0
>
{}]
%
up_lengths_
[
Number
<
1
>
{}]);
}
}
else
{
idx_low
(
Number
<
1
>
{})
=
idx_up
[
Number
<
1
>
{}]
^
idx_up
[
Number
<
0
>
{}];
}
}
template
<
typename
LowIdxDiff
,
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
UpIdxDiff
,
...
...
include/ck/tensor_description/multi_index_transform_helper.hpp
View file @
e4112de7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -128,9 +128,15 @@ __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
...
@@ -128,9 +128,15 @@ __host__ __device__ constexpr auto make_modulo_transform(const Modulus& modulus,
return
Modulo
<
Modulus
,
UpLength
>
{
modulus
,
up_length
};
return
Modulo
<
Modulus
,
UpLength
>
{
modulus
,
up_length
};
}
}
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_xor_with_modulo_transform
(
const
LowLengths
&
low_lengths
)
{
return
Xor
<
LowLengths
,
true
/*ApplyModulo*/
>
{
low_lengths
};
}
template
<
typename
LowLengths
>
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_xor_transform
(
const
LowLengths
&
low_lengths
)
__host__
__device__
constexpr
auto
make_xor_transform
(
const
LowLengths
&
low_lengths
)
{
{
return
Xor
<
LowLengths
>
{
low_lengths
};
return
Xor
<
LowLengths
,
false
/*ApplyModulo*/
>
{
low_lengths
};
}
}
}
// namespace ck
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
e4112de7
...
@@ -53,8 +53,7 @@ __global__ void
...
@@ -53,8 +53,7 @@ __global__ void
e_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock
,
const
Block2ETileMap
block_2_etile_map
)
const
Block2ETileMap
block_2_etile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx94__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp
View file @
e4112de7
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
View file @
e4112de7
...
@@ -45,8 +45,7 @@ __global__ void
...
@@ -45,8 +45,7 @@ __global__ void
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
)
const
CDEElementwiseOperation
cde_element_op
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
const
index_t
KBatch
=
1
;
const
index_t
KBatch
=
1
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_fpAintB_gemm_wmma.hpp
View file @
e4112de7
...
@@ -50,8 +50,7 @@ __global__ void
...
@@ -50,8 +50,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
const
Block2CTileMap
block_2_ctile_map
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__) || defined(__gfx1101__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
defined(__gfx1102__))
__shared__
char
p_shared
[
GridwiseGemm
::
SharedMemTrait
::
lds_size
];
__shared__
char
p_shared
[
GridwiseGemm
::
SharedMemTrait
::
lds_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
@@ -80,7 +79,7 @@ __global__ void
...
@@ -80,7 +79,7 @@ __global__ void
ignore
=
b_element_op
;
ignore
=
b_element_op
;
ignore
=
c_element_op
;
ignore
=
c_element_op
;
ignore
=
block_2_ctile_map
;
ignore
=
block_2_ctile_map
;
#endif // end of if (defined(__gfx11
00
__))
#endif // end of if (defined(__gfx11__))
}
}
// Assume B is Col-Major
// Assume B is Col-Major
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp
0 → 100644
View file @
e4112de7
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
e4112de7
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -34,8 +34,7 @@ __global__ void
...
@@ -34,8 +34,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1)))
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3
(
typename
GridwiseGemm
::
Argument
karg
)
kernel_gemm_xdl_cshuffle_v3
(
typename
GridwiseGemm
::
Argument
karg
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
auto
splitk_batch_offset
=
typename
GridwiseGemm
::
SplitKBatchOffset
(
karg
);
...
@@ -48,7 +47,7 @@ __global__ void
...
@@ -48,7 +47,7 @@ __global__ void
karg
);
karg
);
#else
#else
ignore
=
karg
;
ignore
=
karg
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
}
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
...
@@ -63,8 +62,7 @@ __global__ void
...
@@ -63,8 +62,7 @@ __global__ void
// __attribute__((amdgpu_waves_per_eu(1, 1)))
// __attribute__((amdgpu_waves_per_eu(1, 1)))
kernel_gemm_xdl_cshuffle_v3_2lds
(
typename
GridwiseGemm
::
Argument
karg
)
kernel_gemm_xdl_cshuffle_v3_2lds
(
typename
GridwiseGemm
::
Argument
karg
)
{
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
// Pass two lds pointer is the key to tell compiler that ds_read/write
// Pass two lds pointer is the key to tell compiler that ds_read/write
// operate on different lds chunk at same time without order dependecy
// operate on different lds chunk at same time without order dependecy
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared_0
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
@@ -81,7 +79,7 @@ __global__ void
...
@@ -81,7 +79,7 @@ __global__ void
karg
);
karg
);
#else
#else
ignore
=
karg
;
ignore
=
karg
;
#endif // end of if (defined(__gfx9
08__) || defined(__gfx90a
__))
#endif // end of if (defined(__gfx9__))
}
}
template
<
typename
ALayout
,
template
<
typename
ALayout
,
...
@@ -605,8 +603,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -605,8 +603,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
a_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
Number
<
MPerBlock
/
MLdsLayer
>
{},
make_tuple
(
make_xor_
with_modulo_
transform
(
make_tuple
(
Number
<
AK0Number
*
MLdsLayer
>
{})),
Number
<
MPerBlock
/
MLdsLayer
>
{},
Number
<
AK0Number
*
MLdsLayer
>
{})),
make_pass_through_transform
(
AK1Number
)),
make_pass_through_transform
(
AK1Number
)),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}));
...
@@ -671,7 +669,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -671,7 +669,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple
(
make_tuple
(
make_pass_through_transform
(
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
Number
<
K0PerThreadWrite
>
{}),
make_pass_through_transform
(
Number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_xor_
with_modulo_
transform
(
make_tuple
(
Number
<
KThreadReadPerm
*
M1
>
{},
Number
<
kfold
*
M0
/
mpair
>
{})),
make_tuple
(
Number
<
KThreadReadPerm
*
M1
>
{},
Number
<
kfold
*
M0
/
mpair
>
{})),
make_pass_through_transform
(
Number
<
mpair
>
{}),
make_pass_through_transform
(
Number
<
mpair
>
{}),
make_pass_through_transform
(
AK1Number
)),
make_pass_through_transform
(
AK1Number
)),
...
@@ -742,8 +740,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -742,8 +740,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
b_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
Number
<
NPerBlock
/
NLdsLayer
>
{},
make_tuple
(
make_xor_
with_modulo_
transform
(
make_tuple
(
Number
<
BK0Number
*
NLdsLayer
>
{})),
Number
<
NPerBlock
/
NLdsLayer
>
{},
Number
<
BK0Number
*
NLdsLayer
>
{})),
make_pass_through_transform
(
BK1Number
)),
make_pass_through_transform
(
BK1Number
)),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}));
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}));
...
@@ -805,7 +803,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -805,7 +803,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple
(
make_tuple
(
make_pass_through_transform
(
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
Number
<
K0PerThreadWrite
>
{}),
make_pass_through_transform
(
Number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_xor_
with_modulo_
transform
(
make_tuple
(
Number
<
KThreadReadPerm
*
N1
>
{},
Number
<
kfold
*
N0
/
npair
>
{})),
make_tuple
(
Number
<
KThreadReadPerm
*
N1
>
{},
Number
<
kfold
*
N0
/
npair
>
{})),
make_pass_through_transform
(
Number
<
npair
>
{}),
make_pass_through_transform
(
Number
<
npair
>
{}),
make_pass_through_transform
(
BK1Number
)),
make_pass_through_transform
(
BK1Number
)),
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment