gaoqiong / composable_kernel_ROCM / Commits / 95907384

Commit 95907384 (unverified)
Authored Jul 03, 2024 by Jun Liu; committed via GitHub on Jul 03, 2024
Parent: 497ccb87

    Fix issue with multiple targets and remove smfmac tests from unsupported test targets (#1372)

Showing 20 changed files (page 1 of 2; 24 files changed in total) with 227 additions and 44 deletions (+227 -44)
Changed files on this page:

    Jenkinsfile  (+2 -2)
    client_example/25_wrapper/wrapper_basic_gemm.cpp  (+15 -2)
    client_example/25_wrapper/wrapper_optimized_gemm.cpp  (+14 -2)
    example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp  (+9 -0)
    example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp  (+9 -0)
    example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp  (+12 -1)
    example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp  (+12 -1)
    example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp  (+12 -1)
    example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp  (+12 -1)
    example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp  (+12 -1)
    example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp  (+12 -1)
    example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp  (+12 -1)
    example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp  (+12 -1)
    example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp  (+12 -1)
    include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp  (+10 -10)
    include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp  (+15 -14)
    include/ck/utility/amd_smfmac.hpp  (+28 -0)
    test/CMakeLists.txt  (+5 -1)
    test/grouped_convnd_bwd_data/CMakeLists.txt  (+4 -4)
    test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp  (+8 -0)
Jenkinsfile

@@ -886,10 +886,10 @@ pipeline {
         }
         agent{ label rocmnode("gfx90a") }
         environment{
-            setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx908;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
+            setup_args = """ -DCMAKE_INSTALL_PREFIX=../install -DGPU_TARGETS="gfx1100;gfx90a" -DCMAKE_CXX_FLAGS=" -O3 " """
             execute_args = """ cd ../client_example && rm -rf build && mkdir build && cd build && \
                 cmake -DCMAKE_PREFIX_PATH="${env.WORKSPACE}/install;/opt/rocm" \
-                -DGPU_TARGETS="gfx908;gfx90a" \
+                -DGPU_TARGETS="gfx1100;gfx90a" \
                 -DCMAKE_CXX_COMPILER="${build_compiler()}" \
                 -DCMAKE_CXX_FLAGS=" -O3 " .. && make -j """
         }
client_example/25_wrapper/wrapper_basic_gemm.cpp

@@ -7,19 +7,23 @@
 #include <initializer_list>
 #include <vector>
 
+#include "ck/utility/common_header.hpp"
+// __gfx9__ defined in the above header via ck.hpp
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
-#include "ck/library/utility/host_tensor.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/check_err.hpp"
-#include "ck/utility/common_header.hpp"
 #include "ck/library/utility/fill.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 #include "ck/wrapper/layout.hpp"
 #include "ck/wrapper/tensor.hpp"
 #include "ck/wrapper/operations/copy.hpp"
 #include "ck/wrapper/operations/gemm.hpp"
 #include "ck/wrapper/utils/kernel_utils.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 struct SimpleDeviceMem
 {
@@ -204,6 +208,14 @@ void PerformGemm(const ck::index_t M,
 int main(int argc, char* argv[])
 {
+    bool is_supported = ck::is_xdl_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: xdl example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+
     using DataType = ck::half_t;
     const auto thread_layout = ck::wrapper::make_layout(
         ck::make_tuple(ck::Number<64>{}, ck::Number<4>{}),
@@ -213,3 +225,4 @@ int main(int argc, char* argv[])
         3840, 4096, 4096, tile_shape, thread_layout);
     return 0;
 }
+#endif
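
The pattern above combines two guards. The preprocessor condition keeps the whole translation unit empty during device-side compilation for non-gfx9 targets, so a multi-target build such as -DGPU_TARGETS="gfx1100;gfx90a" no longer trips over gfx9-only code paths; the ck::is_xdl_supported() call then skips execution at runtime on machines whose GPU cannot run XDL kernels. Below is a minimal self-contained sketch of the idiom; device_supports_xdl() is a hypothetical stand-in for ck::is_xdl_supported() written against the public HIP API, and __gfx9__ is CK's umbrella macro from its common header (it is not compiler-defined).

    #include <cstring>
    #include <iostream>
    #include <hip/hip_runtime.h>

    // Host pass (__HIP_DEVICE_COMPILE__ undefined) always compiles this body;
    // a device pass compiles it only when targeting gfx9, so the object code
    // for, e.g., gfx1100 stays empty instead of failing on gfx9-only builtins.
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))

    // Hypothetical stand-in for ck::is_xdl_supported(): XDL (MFMA) kernels run
    // on gfx908, gfx90a and gfx94x, so match those arch-name prefixes.
    bool device_supports_xdl()
    {
        hipDeviceProp_t props{};
        if(hipGetDeviceProperties(&props, 0) != hipSuccess)
            return false;
        const char* name = props.gcnArchName; // e.g. "gfx90a:sramecc+:xnack-"
        return std::strncmp(name, "gfx908", 6) == 0 || std::strncmp(name, "gfx90a", 6) == 0 ||
               std::strncmp(name, "gfx94", 5) == 0;
    }

    int main()
    {
        if(!device_supports_xdl())
        {
            // Graceful skip instead of a runtime error on unsupported GPUs.
            std::cout << "WARNING: xdl example not supported on this platform" << std::endl;
            return 0;
        }
        // ... build tensors and launch the gfx9-only GEMM here ...
        return 0;
    }
    #endif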
client_example/25_wrapper/wrapper_optimized_gemm.cpp

@@ -7,18 +7,21 @@
 #include <initializer_list>
 #include <vector>
 
-#include "ck/library/utility/host_tensor.hpp"
+#include "ck/utility/common_header.hpp"
+// __gfx9__ defined in the above header via ck.hpp
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
 #include "ck/host_utility/kernel_launch.hpp"
 #include "ck/library/utility/device_memory.hpp"
 #include "ck/library/utility/check_err.hpp"
-#include "ck/utility/common_header.hpp"
 #include "ck/library/utility/fill.hpp"
+#include "ck/library/utility/host_tensor.hpp"
 #include "ck/wrapper/layout.hpp"
 #include "ck/wrapper/tensor.hpp"
 #include "ck/wrapper/operations/copy.hpp"
 #include "ck/wrapper/operations/gemm.hpp"
 #include "ck/wrapper/utils/kernel_utils.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 struct SimpleDeviceMem
 {
@@ -296,6 +299,14 @@ void PerformGemm(const ck::index_t M,
 int main(int argc, char* argv[])
 {
+    bool is_supported = ck::is_xdl_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: xdl example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+
     using DataType = ck::half_t;
     const auto thread_layout = ck::wrapper::make_layout(
         ck::make_tuple(ck::Number<4>{}, ck::Number<64>{}, ck::Number<1>{}),
@@ -305,3 +316,4 @@ int main(int argc, char* argv[])
         3840, 4096, 4096, tile_shape, thread_layout);
     return 0;
 }
+#endif
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp

@@ -17,6 +17,7 @@
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
 #include "ck/library/utility/check_err.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 struct AlphaBetaAdd
 {
@@ -175,6 +176,14 @@ int main(int argc, char* argv[])
         exit(0);
     }
 
+    bool is_supported = ck::is_gfx11_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: wmma example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+
     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
             using namespace ck::literals;
example/02_gemm_bilinear/gemm_bilinear_wmma_int8.cpp

@@ -17,6 +17,7 @@
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
 #include "ck/library/utility/check_err.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 struct AlphaBetaAdd
 {
@@ -175,6 +176,14 @@ int main(int argc, char* argv[])
         exit(0);
     }
 
+    bool is_supported = ck::is_gfx11_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: wmma example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+
     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
             using namespace ck::literals;
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp

@@ -2,6 +2,7 @@
 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "common_wmma.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 // kernel data types
 using InKernelDataType = FP16;
@@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
 
 #include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
 
-int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
+int main(int argc, char* argv[])
+{
+    bool is_supported = ck::is_gfx11_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: wmma example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+    return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv);
+}
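
One detail worth noting in these rewritten mains: the run_*() helpers return true on success, so "return !run_...(argc, argv);" maps success to process exit status 0 and failure to 1, which is what shells and CI harnesses expect. In sketch form, with a hypothetical helper name:

    // Hypothetical helper in the style of the examples' run_*() functions:
    // returns true when the kernel ran and the result verified correctly.
    bool run_example() { return true; }

    int main()
    {
        // true -> !true -> 0 (success); false -> 1 (failure).
        return !run_example();
    }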
example/30_grouped_conv_fwd_multiple_d/grouped_conv_fwd_bias_relu_add_wmma_int8.cpp

@@ -2,6 +2,7 @@
 // Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
 
 #include "common_wmma.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 // kernel data types
 using InKernelDataType = I8;
@@ -23,4 +24,14 @@ using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;
 
 #include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"
 
-int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
+int main(int argc, char* argv[])
+{
+    bool is_supported = ck::is_gfx11_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: wmma example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+    return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv);
+}
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_lower_triangle_scale_softmax_gemm_permute_wmma_fp16.cpp

@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -163,4 +164,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 
 #include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
 
-int main(int argc, char* argv[]) { return run(argc, argv); }
+int main(int argc, char* argv[])
+{
+    bool is_supported = ck::is_gfx11_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: wmma example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+    return run(argc, argv);
+}
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp

@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -285,4 +286,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 
 #include "run_batched_gemm_scale_softmax_gemm_permute_wmma.inc"
 
-int main(int argc, char* argv[]) { return run(argc, argv); }
+int main(int argc, char* argv[])
+{
+    bool is_supported = ck::is_gfx11_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: wmma example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+    return run(argc, argv);
+}
example/32_batched_gemm_scale_softmax_gemm/cross_attention_forward_wmma_fp16.cpp

@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -351,4 +352,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 
 #include "run_cross_attention_wmma.inc"
 
-int main(int argc, char* argv[]) { return run(argc, argv); }
+int main(int argc, char* argv[])
+{
+    bool is_supported = ck::is_gfx11_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: wmma example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+    return run(argc, argv);
+}
example/32_batched_gemm_scale_softmax_gemm/grouped_query_attention_forward_wmma_fp16.cpp

@@ -28,6 +28,7 @@ Example is GQA-4
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -299,4 +300,14 @@ using ReferenceGemm1Instance =
 
 #include "run_grouped_query_attention_forward_wmma.inc"
 
-int main(int argc, char* argv[]) { return run(argc, argv); }
+int main(int argc, char* argv[])
+{
+    bool is_supported = ck::is_gfx11_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: wmma example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+    return run(argc, argv);
+}
example/32_batched_gemm_scale_softmax_gemm/multi_query_attention_forward_wmma_fp16.cpp

@@ -26,6 +26,7 @@ Shazeer, Noam. “Fast Transformer Decoding: One Write-Head Is All You Need.”
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -284,4 +285,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm_
 
 #include "run_multi_query_attention_forward_wmma.inc"
 
-int main(int argc, char* argv[]) { return run(argc, argv); }
+int main(int argc, char* argv[])
+{
+    bool is_supported = ck::is_gfx11_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: wmma example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+    return run(argc, argv);
+}
example/32_batched_gemm_scale_softmax_gemm/self_attention_forward_wmma_fp16.cpp

@@ -27,6 +27,7 @@ Gemm + Softmax + Gemm fused operation. Computes C_g_m_n = Softmax(A_g_m_k * B0_g
 #include "ck/library/utility/literals.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_batched_gemm.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_softmax.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
@@ -329,4 +330,14 @@ using ReferenceGemm1Instance = ck::tensor_operation::host::ReferenceBatchedGemm<
 
 #include "run_self_attention_wmma.inc"
 
-int main(int argc, char* argv[]) { return run(argc, argv); }
+int main(int argc, char* argv[])
+{
+    bool is_supported = ck::is_gfx11_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: wmma example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+    return run(argc, argv);
+}
example/38_grouped_conv_bwd_data_multiple_d/grouped_conv_bwd_data_wmma_fp16.cpp

@@ -3,6 +3,7 @@
 #include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
 
 #include "common.hpp"
+#include "ck/host_utility/device_prop.hpp"
 
 using OutDataType = FP16;
 using WeiDataType = FP16;
@@ -31,4 +32,14 @@ using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDat
 
 #include "run_grouped_conv_bwd_data_example.inc"
 
-int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); }
+int main(int argc, char* argv[])
+{
+    bool is_supported = ck::is_gfx11_supported();
+    if(!is_supported)
+    {
+        std::cout << "WARNING: wmma example not supported on the platform "
+                  << ck::get_device_name() << std::endl;
+        return 0;
+    }
+    return run_grouped_conv_bwd_data_example(argc, argv);
+}
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_xdl_cshuffle.hpp

@@ -47,12 +47,12 @@ __global__ void
 #endif
     kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3(
         typename GridwiseGemm::Argument karg,
-        const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-        const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-            c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-        const index_t num_k_per_block)
+        [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
+        [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
+        [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+        [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+        [[maybe_unused]] const index_t num_k_per_block)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx94__))
@@ -103,12 +103,12 @@ __global__ void
 #endif
     kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds(
         typename GridwiseGemm::Argument karg,
-        const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-        const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-            c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
-        const index_t num_k_per_block)
+        [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
+        [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
+        [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+        [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+        [[maybe_unused]] const index_t num_k_per_block)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
     defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
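
The motivation for the attribute: these kernel bodies are compiled only for the architectures listed in the #if, so on every other device pass the parameters go unreferenced and warnings-as-errors builds fail on -Wunused-parameter. [[maybe_unused]] (C++17) documents that the use is conditional and silences the warning without affecting code generation. A minimal sketch of the same situation, with a hypothetical kernel:

    #include <hip/hip_runtime.h>

    // The body exists only on gfx90a; on other device passes the parameters
    // are never used, which -Wunused-parameter would flag without the attribute.
    template <typename T>
    __global__ void scale_kernel([[maybe_unused]] T* data, [[maybe_unused]] const int n)
    {
    #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx90a__))
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n)
        {
            data[i] *= T(2);
        }
    #endif
    }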
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp

@@ -69,14 +69,15 @@ __global__ void
 #if CK_USE_LAUNCH_BOUNDS
     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
 #endif
-    kernel_grouped_conv_fwd_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg,
-                                            const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
-                                            const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-                                            const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-                                                c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                            const ComputePtrOffset compute_ptr_offset_of_groups,
-                                            const ComputePtrOffset compute_ptr_offset_of_n,
-                                            const index_t groups_count)
+    kernel_grouped_conv_fwd_xdl_cshuffle_v3(
+        typename GridwiseGemm::Argument karg,
+        [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
+        [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
+        [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+        [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
+        [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
+        [[maybe_unused]] const index_t groups_count)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // offset base pointer for each work-group
@@ -132,13 +133,13 @@ __global__ void
 #endif
     kernel_grouped_conv_fwd_xdl_cshuffle_v3_2lds(
         typename GridwiseGemm::Argument karg,
-        const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
-        const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
-        const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
-            c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const ComputePtrOffset compute_ptr_offset_of_groups,
-        const ComputePtrOffset compute_ptr_offset_of_n,
-        const index_t groups_count)
+        [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1,
+        [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1,
+        [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+        [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_groups,
+        [[maybe_unused]] const ComputePtrOffset compute_ptr_offset_of_n,
+        [[maybe_unused]] const index_t groups_count)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
     // offset base pointer for each work-group
include/ck/utility/amd_smfmac.hpp

@@ -16,8 +16,15 @@ struct intrin_smfmac_f32_16x16x32f16<16, 16>
     __device__ static void Run(const half4_t& reg_a,
                                const half8_t& reg_b,
                                const int32_t& reg_idx,
                                FloatC& reg_c)
     {
+#if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_f16(
             reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+        ignore = reg_idx;
+#endif
     }
 };
@@ -31,8 +38,15 @@ struct intrin_smfmac_f32_16x16x32bf16<16, 16>
     __device__ static void Run(const bhalf4_t& reg_a,
                                const bhalf8_t& reg_b,
                                const int32_t& reg_idx,
                                FloatC& reg_c)
     {
+#if defined(__gfx94__)
         reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_16x16x32_bf16(
             reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], reg_idx, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+        ignore = reg_idx;
+#endif
     }
 };
@@ -46,8 +60,15 @@ struct intrin_smfmac_f32_32x32x16f16<32, 32>
     __device__ static void Run(const half4_t& reg_a,
                                const half8_t& reg_b,
                                const int32_t& reg_idx,
                                FloatC& reg_c)
     {
+#if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_f16(
             reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+        ignore = reg_idx;
+#endif
     }
 };
@@ -61,8 +82,15 @@ struct intrin_smfmac_f32_32x32x16bf16<32, 32>
     __device__ static void Run(const bhalf4_t& reg_a,
                                const bhalf8_t& reg_b,
                                const int32_t& reg_idx,
                                FloatC& reg_c)
     {
+#if defined(__gfx94__)
         reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_smfmac_f32_32x32x16_bf16(
             reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], reg_idx, 0, 0);
+#else
+        ignore = reg_a;
+        ignore = reg_b;
+        ignore = reg_c;
+        ignore = reg_idx;
+#endif
     }
 };
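
For context (an inference from the code, not spelled out in the commit): the smfmac builtins are the sparse matrix-fused-multiply-accumulate instructions of gfx940/941/942, where reg_a holds the non-zero elements of a 2:4-structured-sparse A operand, reg_b the dense B operand, and reg_idx the packed sparsity indices. Since the builtins exist only on those targets, every other device pass now compiles the "ignore = ...;" branch instead of failing. A self-contained sketch of the guard shape, without CK's vector-type wrapper; the typedefs are assumptions mirroring CK's data_type.hpp:

    #include <cstdint>
    #include <hip/hip_runtime.h>

    // CK-style short vector types (assumption: mirrors ck's data_type.hpp).
    typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
    typedef _Float16 half8_t __attribute__((ext_vector_type(8)));
    typedef float float4_t __attribute__((ext_vector_type(4)));

    // Sparse MFMA for one 16x16x32 tile. The builtin is compiler-defined only
    // for gfx94x (CK's __gfx94__ umbrella macro covers these targets).
    __device__ void smfmac_16x16x32_f16(const half4_t& reg_a,
                                        const half8_t& reg_b,
                                        const int32_t& reg_idx,
                                        float4_t& reg_c)
    {
    #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
        // The two trailing mode arguments are left 0, as in the CK header.
        reg_c = __builtin_amdgcn_smfmac_f32_16x16x32_f16(reg_a, reg_b, reg_c, reg_idx, 0, 0);
    #else
        // Discard the arguments, in the spirit of ck's `ignore = ...;` sink.
        (void)reg_a;
        (void)reg_b;
        (void)reg_idx;
        (void)reg_c;
    #endif
    }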
test/CMakeLists.txt

@@ -71,6 +71,8 @@ function(add_test_executable TEST_NAME)
         list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
     elseif(ARGN MATCHES "_wmma")
         list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+    elseif(ARGN MATCHES "_smfmac")
+        list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a)
     endif()
     set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
     add_executable(${TEST_NAME} ${ARGN})
@@ -150,6 +152,8 @@ function(add_gtest_executable TEST_NAME)
         list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
     elseif(ARGN MATCHES "_wmma")
         list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+    elseif(ARGN MATCHES "_smfmac")
+        list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx908 gfx90a)
     endif()
     set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
     add_executable(${TEST_NAME} ${ARGN})
@@ -209,7 +213,7 @@ add_subdirectory(wrapper)
 if(GPU_TARGETS MATCHES "gfx11")
     add_subdirectory(wmma_op)
 endif()
-if(GPU_TARGETS MATCHES "gfx942")
+if(GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2)
     # smfmac needs ROCm6.2
     add_subdirectory(smfmac_op)
 endif()
 add_subdirectory(position_embedding)
test/grouped_convnd_bwd_data/CMakeLists.txt

@@ -2,11 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data_x
 if(result EQUAL 0)
     target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
 endif()
-add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp)
+add_gtest_executable(test_grouped_convnd_bwd_data_interface_xdl test_grouped_convnd_bwd_data_interface_xdl.cpp)
 if(result EQUAL 0)
-    target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance)
+    target_link_libraries(test_grouped_convnd_bwd_data_interface_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance)
 endif()
-add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp)
+add_gtest_executable(test_grouped_convnd_bwd_data_interface_wmma test_grouped_convnd_bwd_data_interface_wmma.cpp)
 if(result EQUAL 0)
-    target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance)
+    target_link_libraries(test_grouped_convnd_bwd_data_interface_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance)
 endif()
test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data_interface_wmma.cpp

@@ -52,6 +52,14 @@ class TestGroupedConvndBwdData : public ::testing::Test
     ck::utils::conv::ConvParam conv_param;
 
+    void SetUp() override
+    {
+        if(!ck::is_gfx11_supported())
+        {
+            GTEST_SKIP();
+        }
+    }
+
     template <ck::index_t NDimSpatial>
     bool Run()
     {
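
GTEST_SKIP() inside SetUp() aborts fixture setup and reports every test in the fixture as skipped rather than failed, the right outcome when the hardware simply cannot run the kernels under test. A minimal sketch of the pattern, with hypothetical fixture and predicate names:

    #include <gtest/gtest.h>

    // Hypothetical stand-in for ck::is_gfx11_supported().
    bool device_supports_wmma() { return false; }

    class WmmaTest : public ::testing::Test
    {
        protected:
        void SetUp() override
        {
            // Skips every TEST_F in this fixture on unsupported hardware;
            // the test bodies never run there.
            if(!device_supports_wmma())
            {
                GTEST_SKIP() << "wmma not supported on this device";
            }
        }
    };

    TEST_F(WmmaTest, RunsOnlyOnGfx11)
    {
        SUCCEED();
    }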