gaoqiong / composable_kernel

Commit 50540084, authored Aug 07, 2023 by Adam Osewski
Merge branch 'develop' into aosewski/gemm_tile_loop
Parents: 9c4d265f, 2474dddb

176 changed files in this commit; this page shows the first 20, with 588 additions and 63 deletions.
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp  (+3, -3)
example/01_gemm/gemm_xdl_streamk.cpp  (+49, -0)
example/01_gemm/run_gemm_example.inc  (+141, -28)
example/14_gemm_quantization/CMakeLists.txt  (+3, -1)
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp  (+15, -3)
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp  (+15, -3)
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc  (+6, -0)
example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp  (+3, -1)
example/40_conv2d_fwd_quantization/CMakeLists.txt  (+3, -1)
example/50_put_element/put_element_fp16.cpp  (+4, -4)
include/ck/ck.hpp  (+7, -3)
include/ck/host_utility/device_prop.hpp  (+7, -0)
include/ck/host_utility/kernel_launch.hpp  (+69, -0)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp  (+15, -0)
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp  (+164, -0)
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp  (+1, -2)
include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp  (+64, -0)
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp  (+13, -11)
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp  (+1, -3)
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp  (+5, -0)
example/01_gemm/gemm_xdl_skip_b_lds_fp16.cpp

@@ -204,9 +204,9 @@ int main(int argc, char* argv[])
     if(!gemm.IsSupportedArgument(argument))
     {
-        throw std::runtime_error(
-            "wrong! device_gemm with the specified compilation parameters does "
-            "not support this GEMM problem");
+        std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
+        return 0;
     }

     float ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel});
...
example/01_gemm/gemm_xdl_streamk.cpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "common.hpp"

#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_streamk.hpp"

using ADataType        = ck::half_t;
using BDataType        = ck::half_t;
using AccDataType      = float;
using CShuffleDataType = float;
using CDataType        = ck::half_t;
using F16              = ck::half_t;

using ALayout = Row;
using BLayout = Row;
// using BLayout = Col;
using CLayout = Row;

using AElementOp = PassThrough;
using BElementOp = PassThrough;
using CElementOp = PassThrough;

// clang-format off
using DeviceGemmStreamK = ck::tensor_operation::device::DeviceGemmXdlStreamK
// ######| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
// ######| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
// ######| | | | | | | | Operation| Operation| Operation| | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
// ######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
         < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//       < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 8>, 8>;
//       < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 64, 4, 8, 32, 32, 1, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>;
//       < ADataType, BDataType, CDataType, AccDataType, ALayout, BLayout, CLayout, AElementOp, BElementOp, CElementOp, 128, 32, 128, 4, 8, 32, 32, 1, 1, S<8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 2, 1, 1, 1, S<1, 32, 1, 4>, 8>;
// // clang-format on
// clang-format on

using DeviceGemmInstance = DeviceGemmStreamK;

using ReferenceGemmInstance = ck::tensor_operation::host::
    ReferenceGemm<ADataType, BDataType, CDataType, AccDataType, AElementOp, BElementOp, CElementOp>;

#include "run_gemm_example.inc"

int main(int argc, char* argv[]) { return !run_gemm_streamk_example(argc, argv); }
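The stream-K example drives the shared harness in run_gemm_example.inc (next diff) through run_gemm_streamk_example. The ProblemSizeStreamK type it relies on is defined in the examples' common.hpp, which is not part of this commit; the sketch below is a hypothetical reconstruction of its shape, inferred only from the member accesses visible in run_gemm_example.inc (problem_size.M ... problem_size.NumSKBlocks):

// Hypothetical sketch only -- the real definition lives in the examples'
// common.hpp and is not shown in this commit. Field names are inferred from
// the accesses in run_gemm_example.inc; index_t stands in for ck::index_t.
using index_t = int;

struct ProblemSizeStreamK
{
    index_t M = 0;
    index_t N = 0;
    index_t K = 0;

    index_t StrideA = 0; // 0 = let the example derive a packed stride
    index_t StrideB = 0;
    index_t StrideC = 0;

    index_t NumSKBlocks = 0; // 0 = let the kernel choose the stream-k split
};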
example/01_gemm/run_gemm_example.inc

@@ -3,7 +3,10 @@
 #pragma once

-bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
+#include "ck/tensor_operation/gpu/device/device_gemm_streamk.hpp"
+
+template <typename ProblemType>
+bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
 {
 #if defined(BUILD_INT4_EXAMPLE) && defined(CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4)
     static_assert(sizeof(ck::int4_t) == sizeof(int8_t));
@@ -11,7 +14,12 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     using namespace ck::literals;

-    auto& [M, N, K, StrideA, StrideB, StrideC] = problem_size;
+    auto M       = problem_size.M;
+    auto N       = problem_size.N;
+    auto K       = problem_size.K;
+    auto StrideA = problem_size.StrideA;
+    auto StrideB = problem_size.StrideB;
+    auto StrideC = problem_size.StrideC;

     auto f_host_tensor_descriptor =
         [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
@@ -25,12 +33,37 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
         }
     };

+    auto f_get_default_stride =
+        [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
+            if(stride == 0)
+            {
+                // give a chance if stride is zero, return a default packed stride
+                if constexpr(std::is_same_v<decltype(layout), ck::tensor_layout::gemm::RowMajor>)
+                {
+                    return col;
+                }
+                else
+                {
+                    return row;
+                }
+            }
+            else
+                return stride;
+        };
+
+    StrideA = f_get_default_stride(M, K, StrideA, ALayout{});
+    StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
+    StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
+
     Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
     Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));

     switch(config.init_method)
     {
-    case 0: break;
+    case 0:
+        ck::utils::FillConstant<ADataType>{static_cast<ADataType>(1.f)}(a_m_k);
+        ck::utils::FillConstant<BDataType>{static_cast<BDataType>(1.f)}(b_k_n);
+        break;
     case 1:
         ck::utils::FillUniformDistributionIntegerValue<ADataType>{-5.f, 5.f}(a_m_k);
         ck::utils::FillUniformDistributionIntegerValue<BDataType>{-5.f, 5.f}(b_k_n);
@@ -66,43 +99,115 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
     a_m_k_device_buf.ToDevice(a_m_k.mData.data());
     b_k_n_device_buf.ToDevice(b_k_n.mData.data());
 #endif

+    DeviceMem workspace;
+
     auto a_element_op = AElementOp{};
     auto b_element_op = BElementOp{};
     auto c_element_op = CElementOp{};

+    using BaseStreamK = ck::tensor_operation::device::DeviceGemmStreamK<ALayout,
+                                                                        BLayout,
+                                                                        CLayout,
+                                                                        ADataType,
+                                                                        BDataType,
+                                                                        CDataType,
+                                                                        AElementOp,
+                                                                        BElementOp,
+                                                                        CElementOp>;
+
     // do GEMM
     auto gemm    = DeviceGemmInstance{};
     auto invoker = gemm.MakeInvoker();

-    auto argument = gemm.MakeArgument(
-#ifdef BUILD_INT4_EXAMPLE
-        static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
-        static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
-        static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
-#else
-        static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
-        static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
-        static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
-#endif
-        M,
-        N,
-        K,
-        StrideA,
-        StrideB,
-        StrideC,
-        a_element_op,
-        b_element_op,
-        c_element_op);
-
-    if(!gemm.IsSupportedArgument(argument))
-    {
-        std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
-        return true;
-    }
-
-    float ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+    float ave_time = 0;
+
+    if constexpr(std::is_same<ProblemType, ProblemSize>::value &&
+                 !std::is_base_of<BaseStreamK, DeviceGemmInstance>::value)
+    {
+        auto argument = gemm.MakeArgument(
+#ifdef BUILD_INT4_EXAMPLE
+            static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+            static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+            static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+#else
+            static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+            static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+            static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+#endif
+            M,
+            N,
+            K,
+            StrideA,
+            StrideB,
+            StrideC,
+            a_element_op,
+            b_element_op,
+            c_element_op);
+
+        if(!gemm.IsSupportedArgument(argument))
+        {
+            std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
+            return true;
+        }
+
+        ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+    }
+    else if constexpr(std::is_same<ProblemType, ProblemSizeStreamK>::value &&
+                      std::is_base_of<BaseStreamK, DeviceGemmInstance>::value)
+    {
+        auto argument = gemm.MakeArgument(
+#ifdef BUILD_INT4_EXAMPLE
+            static_cast<KernelADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+            static_cast<KernelBDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+            static_cast<KernelCDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+#else
+            static_cast<ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
+            static_cast<BDataType*>(b_k_n_device_buf.GetDeviceBuffer()),
+            static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()),
+#endif
+            M,
+            N,
+            K,
+            StrideA,
+            StrideB,
+            StrideC,
+            a_element_op,
+            b_element_op,
+            c_element_op,
+            problem_size.NumSKBlocks);
+
+        if(!gemm.IsSupportedArgument(argument))
+        {
+            std::cerr << gemm.GetTypeString() << " does not support this problem" << std::endl;
+            return true;
+        }
+
+        std::size_t workspace_size = gemm.GetWorkSpaceSize(&argument);
+        if(workspace_size != 0)
+        {
+            workspace.Realloc(workspace_size);
+            gemm.SetWorkSpacePointer(&argument, workspace.GetDeviceBuffer());
+        }
+
+        ave_time = invoker.Run(argument, StreamConfig{nullptr, config.time_kernel});
+#if 0
+        // TODO!!!!!
+        if(workspace_size != 0){
+            float* ws_ptr = reinterpret_cast<float*>(malloc(workspace_size));
+            size_t ws_dwords = workspace_size / sizeof(float);
+            workspace.FromDevice(ws_ptr);
+            for(size_t i = 0; i < ws_dwords; i++)
+            {
+                uint32_t rere = reinterpret_cast<uint32_t*>(ws_ptr)[i];
+                printf("%4lu : %f(0x%08x)\n", i, ws_ptr[i], rere);
+            }
+            free(ws_ptr);
+        }
+#endif
+    }

     std::size_t flop = 2_uz * M * N * K;
     std::size_t num_btype =
         sizeof(ADataType) * M * K + sizeof(BDataType) * K * N + sizeof(CDataType) * M * N;
...
@@ -149,3 +254,11 @@ bool run_gemm_example(int argc, char* argv[])
     return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
 }
+
+bool run_gemm_streamk_example(int argc, char* argv[])
+{
+    ProblemSizeStreamK problem_size;
+    ExecutionConfig config;
+
+    return !parse_cmd_args(argc, argv, problem_size, config) || run_gemm(problem_size, config);
+}
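The f_get_default_stride helper added above is the piece most worth internalizing: a stride of 0 now means "derive the packed leading dimension from the layout". A minimal standalone sketch of the same logic, with plain tag types standing in for ck::tensor_layout::gemm::RowMajor/ColumnMajor (not part of the commit):

// Standalone sketch of the packed-stride defaulting used above. The leading
// dimension of a packed row-major M x K matrix is K (its column count); for
// column-major it is M (its row count). An explicit nonzero stride wins.
#include <cstddef>
#include <iostream>
#include <type_traits>

struct RowMajor {};
struct ColumnMajor {};

template <typename Layout>
std::size_t get_default_stride(std::size_t row, std::size_t col, std::size_t stride, Layout)
{
    if(stride == 0)
    {
        if constexpr(std::is_same_v<Layout, RowMajor>)
            return col;
        else
            return row;
    }
    return stride;
}

int main()
{
    // M = 128, K = 64
    std::cout << get_default_stride(128, 64, 0, RowMajor{}) << '\n';    // 64
    std::cout << get_default_stride(128, 64, 0, ColumnMajor{}) << '\n'; // 128
    std::cout << get_default_stride(128, 64, 70, RowMajor{}) << '\n';   // 70 (explicit stride kept)
}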
example/14_gemm_quantization/CMakeLists.txt

+if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
 # dlops
 add_example_executable(example_gemm_dl_quantization_int8 gemm_dl_quantization_int8.cpp)
...
@@ -10,4 +11,5 @@ foreach(gpu IN LISTS GPU_TARGETS)
         add_example_executable(example_gemm_xdl_quantization_int8 gemm_xdl_quantization_int8.cpp)
         set(target 1)
     endif()
-endforeach()
\ No newline at end of file
+endforeach()
+endif()
\ No newline at end of file
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_bf16.cpp

@@ -3,7 +3,7 @@
 #include "common.hpp"

-#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"

 using InDataType = BF16;
 // bf16 kernel use fp32 atomic add to accumulate Weight tensor into global memory
@@ -17,8 +17,20 @@ using OutElementOp = PassThrough;
 template <ck::index_t NDimSpatial>
 using DeviceConvBwdWeightInstance =
-    ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
-        NDimSpatial, // NDimSpatial
+    ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
+        NDimSpatial,
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GNWC,
+                                      ck::tensor_layout::convolution::GNHWC,
+                                      ck::tensor_layout::convolution::GNDHWC>>,
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GKXC,
+                                      ck::tensor_layout::convolution::GKYXC,
+                                      ck::tensor_layout::convolution::GKZYXC>>,
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GNWK,
+                                      ck::tensor_layout::convolution::GNHWK,
+                                      ck::tensor_layout::convolution::GNDHWK>>,
         InDataType,  // InDataType
         WeiDataType, // WeiDataType
         OutDataType, // OutDataType
...
example/20_grouped_conv_bwd_weight/grouped_conv_bwd_weight_xdl_fp16.cpp

@@ -3,7 +3,7 @@
 #include "common.hpp"

-#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp"

 using InDataType  = F16;
 using WeiDataType = F16;
@@ -16,8 +16,20 @@ using OutElementOp = PassThrough;
 template <ck::index_t NDimSpatial>
 using DeviceConvBwdWeightInstance =
-    ck::tensor_operation::device::DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle<
-        NDimSpatial, // NDimSpatial
+    ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
+        NDimSpatial,
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GNWC,
+                                      ck::tensor_layout::convolution::GNHWC,
+                                      ck::tensor_layout::convolution::GNDHWC>>,
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GKXC,
+                                      ck::tensor_layout::convolution::GKYXC,
+                                      ck::tensor_layout::convolution::GKZYXC>>,
+        ck::tuple_element_t<NDimSpatial - 1,
+                            ck::Tuple<ck::tensor_layout::convolution::GNWK,
+                                      ck::tensor_layout::convolution::GNHWK,
+                                      ck::tensor_layout::convolution::GNDHWK>>,
         InDataType,  // InDataType
         WeiDataType, // WeiDataType
         OutDataType, // OutDataType
...
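Both conv-bwd-weight examples above now pick their tensor layouts at compile time by indexing a type list with NDimSpatial - 1. A standalone sketch of that idiom, with std::tuple_element_t standing in for ck::tuple_element_t and empty structs for the CK layout tags (not part of the commit):

// Compile-time layout selection: NDimSpatial - 1 indexes a type list, so a
// 1D/2D/3D instance picks GNWC/GNHWC/GNDHWC from one alias template (and
// likewise for the weight and output layout lists in the diffs above).
#include <tuple>
#include <type_traits>

struct GNWC {};
struct GNHWC {};
struct GNDHWC {};

template <int NDimSpatial>
using InLayout = std::tuple_element_t<NDimSpatial - 1, std::tuple<GNWC, GNHWC, GNDHWC>>;

static_assert(std::is_same_v<InLayout<1>, GNWC>);
static_assert(std::is_same_v<InLayout<2>, GNHWC>);
static_assert(std::is_same_v<InLayout<3>, GNDHWC>);

int main() {}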
example/20_grouped_conv_bwd_weight/run_grouped_conv_bwd_weight_example.inc

@@ -75,6 +75,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
     std::array<ck::index_t, NDimSpatial> input_spatial_lengths{};
     std::array<ck::index_t, NDimSpatial> filter_spatial_lengths{};
     std::array<ck::index_t, NDimSpatial> output_spatial_lengths{};
+    std::array<ck::index_t, NDimSpatial + 3> input_strides{};
+    std::array<ck::index_t, NDimSpatial + 3> output_strides{};
     std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
     std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
     std::array<ck::index_t, NDimSpatial> input_left_pads{};
@@ -85,6 +87,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
     range_copy(conv_param.input_spatial_lengths_, begin(input_spatial_lengths));
     range_copy(conv_param.filter_spatial_lengths_, begin(filter_spatial_lengths));
     range_copy(conv_param.output_spatial_lengths_, begin(output_spatial_lengths));
+    range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
+    range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
     range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
     range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
     range_copy(conv_param.input_left_pads_, begin(input_left_pads));
@@ -103,6 +107,8 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
     input_spatial_lengths,
     filter_spatial_lengths,
     output_spatial_lengths,
+    input_strides,
+    output_strides,
     conv_filter_strides,
     conv_filter_dilations,
     input_left_pads,
...
example/37_batched_gemm_add_add_relu_gemm_add/batched_gemm_add_add_relu_gemm_add_xdl_fp16.cpp

@@ -173,6 +173,8 @@ using DeviceGemmInstance =
     8,
     8,
     true,
+    9,           // D0sTransferSrcVectorDim
+    4,           // D0sTransferSrcScalaerPerVector
     S<8, 32, 1>, // B1BlockTransfer
     S<0, 2, 1>,
     S<0, 2, 1>,
...
@@ -189,7 +191,7 @@ int main(int argc, char* argv[])
 {
     bool do_verification = true;
     int init_method      = 1;
-    bool time_kernel     = false;
+    bool time_kernel     = true;

     // GEMM shape
     ck::index_t M = 1024;
...
example/40_conv2d_fwd_quantization/CMakeLists.txt

+if(DTYPES MATCHES "int8" OR NOT DEFINED DTYPES)
 list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
 set(target 0)
 foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -25,4 +26,5 @@ add_example_executable(example_conv2d_fwd_dl_bias_relu_perchannel_quantization_i
 add_example_executable(example_conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8 conv2d_fwd_dl_bias_tanh_perlayer_quantization_int8.cpp)

 # Conv + bias + tanh perchannel quantization
-add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
\ No newline at end of file
+add_example_executable(example_conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8 conv2d_fwd_dl_bias_tanh_perchannel_quantization_int8.cpp)
+endif()
\ No newline at end of file
example/50_put_element/put_element_fp16.cpp

@@ -34,9 +34,9 @@ int main()
     int N = 1024;

-    Tensor<XDataType> x(HostTensorDescriptor{N, 1});
-    Tensor<IndexDataType> indices(HostTensorDescriptor{N, 1});
-    Tensor<YDataType> y(HostTensorDescriptor{N, 1});
+    Tensor<XDataType> x(HostTensorDescriptor{N});
+    Tensor<IndexDataType> indices(HostTensorDescriptor{N});
+    Tensor<YDataType> y(HostTensorDescriptor{N});

     x.GenerateTensorValue(GeneratorTensor_3<XDataType>{-1.0, 1.0});
     for(int i = 0; i < N; ++i)
...
@@ -72,7 +72,7 @@ int main()
     bool pass = true;
     if(do_verification)
     {
-        Tensor<YDataType> y_host(HostTensorDescriptor{N, 1});
+        Tensor<YDataType> y_host(HostTensorDescriptor{N});
         for(int i = 0; i < N; ++i)
         {
...
include/ck/ck.hpp

@@ -118,8 +118,12 @@
 // inline asm
 #define CK_USE_AMD_INLINE_ASM 1

-// inner product (DLOP)
-#define CK_USE_AMD_INNER_PRODUCT_INLINE_ASM 1
+// inner product (V_MAC/V_FMAC)
+#define CK_USE_AMD_V_MAC_INLINE_ASM 1
+
+// V_DOT inline instructions, less efficient since they require adding
+// `s_nop`s to avoid hazard
+#define CK_USE_AMD_V_DOT_INLINE_ASM 0

 // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
 #define CK_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
...
@@ -194,7 +198,7 @@
 #define CK_WORKAROUND_SWDEV_388832 1

 // workaround: Grouped Conv2d_bwd_data fails for already implemented instance
-#define CK_WORKAROUND_SWDEV_3318619 0
+#define CK_WORKAROUND_GITHUB_ISSUE_824 1

 // flag to enable (1) or disable (0) the debugging output in some kernels
 #define DEBUG_LOG 0
...
include/ck/host_utility/device_prop.hpp

@@ -51,4 +51,11 @@ inline std::string get_device_name()
     return name;
 }

+inline bool is_xdl_supported()
+{
+    return ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+           ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+           ck::get_device_name() == "gfx942";
+}
+
 } // namespace ck
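is_xdl_supported centralizes the per-device gfx checks that were previously copy-pasted into each IsSupportedArgument (see the device_batched_contraction_multiple_d and device_batched_gemm_e_permute diffs below). An equivalent standalone formulation of the same check against a name list, shown only as a sketch (the committed helper compares ck::get_device_name() against each name explicitly; the device names are copied from the diff):

#include <algorithm>
#include <array>
#include <cstdio>
#include <string>

// Sketch only, not the committed code: one name-list lookup instead of five
// string comparisons against repeated get_device_name() calls.
inline bool is_xdl_supported(const std::string& device_name)
{
    constexpr std::array<const char*, 5> xdl_devices{
        "gfx908", "gfx90a", "gfx940", "gfx941", "gfx942"};

    return std::any_of(xdl_devices.begin(), xdl_devices.end(),
                       [&](const char* d) { return device_name == d; });
}

int main()
{
    std::printf("%d\n", is_xdl_supported("gfx90a"));  // 1
    std::printf("%d\n", is_xdl_supported("gfx1030")); // 0
}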
include/ck/host_utility/kernel_launch.hpp

@@ -73,3 +73,72 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
     return 0;
 #endif
 }
+
+template <typename... Args, typename F, typename PreProcessFunc>
+float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
+                                             PreProcessFunc preprocess,
+                                             F kernel,
+                                             dim3 grid_dim,
+                                             dim3 block_dim,
+                                             std::size_t lds_byte,
+                                             Args... args)
+{
+#if CK_TIME_KERNEL
+    if(stream_config.time_kernel_)
+    {
+#if DEBUG_LOG
+        printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
+               __func__,
+               grid_dim.x,
+               grid_dim.y,
+               grid_dim.z,
+               block_dim.x,
+               block_dim.y,
+               block_dim.z);
+
+        printf("Warm up 1 time\n");
+#endif
+        // warm up
+        preprocess();
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+
+        const int nrepeat = 10;
+#if DEBUG_LOG
+        printf("Start running %d times...\n", nrepeat);
+#endif
+        hipEvent_t start, stop;
+
+        hip_check_error(hipEventCreate(&start));
+        hip_check_error(hipEventCreate(&stop));
+
+        hip_check_error(hipDeviceSynchronize());
+        hip_check_error(hipEventRecord(start, stream_config.stream_id_));
+
+        for(int i = 0; i < nrepeat; ++i)
+        {
+            preprocess();
+            kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+        }
+
+        hip_check_error(hipEventRecord(stop, stream_config.stream_id_));
+        hip_check_error(hipEventSynchronize(stop));
+
+        float total_time = 0;
+
+        hip_check_error(hipEventElapsedTime(&total_time, start, stop));
+
+        return total_time / nrepeat;
+    }
+    else
+    {
+        preprocess();
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+
+        return 0;
+    }
+#else
+    kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+
+    return 0;
+#endif
+}
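The new launch_and_time_kernel_with_preprocess differs from the existing launch_and_time_kernel in one respect: preprocess() runs before the warm-up launch and before every timed repetition, so per-run setup such as re-zeroing an accumulation workspace is charged to each measured iteration. A host-only sketch of the same timing pattern, with std::chrono standing in for hipEvent_t and a plain callable for the kernel launch (illustrative only, not the committed code):

#include <chrono>
#include <cstdio>

template <typename PreProcessFunc, typename F>
float time_with_preprocess(PreProcessFunc preprocess, F kernel, int nrepeat = 10)
{
    // warm up once, untimed (mirrors the "Warm up 1 time" step above)
    preprocess();
    kernel();

    const auto start = std::chrono::steady_clock::now();
    for(int i = 0; i < nrepeat; ++i)
    {
        preprocess(); // charged to every repetition, as in the HIP version
        kernel();
    }
    const auto stop = std::chrono::steady_clock::now();

    const float total_ms = std::chrono::duration<float, std::milli>(stop - start).count();
    return total_ms / nrepeat; // average time per repetition
}

int main()
{
    volatile long sink = 0;
    const float avg_ms = time_with_preprocess(
        [&] { sink = 0; }, // e.g. clear a workspace before each run
        [&] {
            for(long i = 0; i < 1000000; ++i)
                sink = sink + i; // stand-in workload
        });
    std::printf("avg: %f ms\n", avg_ms);
}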
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp

@@ -94,6 +94,21 @@ struct ThreadGroupTensorSliceTransfer_v4r1
         }
     }

+    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
+    {
+        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
+           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
+        {
+            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
+                make_multi_index(ThreadGroup::GetThreadId()));
+
+            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;
+
+            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
+                                                   src_block_slice_origin + thread_data_idx_begin);
+        }
+    }
+
     template <typename SrcBuffer, index_t ThreadScratchId = 0>
     __device__ void RunRead(const SrcDesc& src_desc,
                             const SrcBuffer& src_buf,
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1r2.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v6r1r2.hpp"

namespace ck {

// this version does following things to avoid scratch memory issue
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <typename ThreadGroup,
          typename ElementwiseOperation,
          typename SliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename DimAccessOrder,
          index_t VectorDim,
          index_t ScalarPerVector,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct ThreadGroupTensorSliceTransfer_v6r1r2
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    static constexpr auto thread_slice_lengths = SliceLengths{} / ThreadClusterLengths{};

    using Index = MultiIndex<nDim>;

    __device__ constexpr ThreadGroupTensorSliceTransfer_v6r1r2(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin,
        const ElementwiseOperation& element_op)
        : threadwise_transfer_(src_desc,
                               make_zero_multi_index<nDim>(),
                               dst_desc,
                               make_zero_multi_index<nDim>(),
                               element_op)
    {
        static_assert(nDim == remove_cvref_t<SrcDesc>::GetNumOfDimension() &&
                          nDim == remove_cvref_t<DstDesc>::GetNumOfDimension() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == DimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<SliceLengths, decltype(thread_slice_lengths * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(ThreadGroup::GetNumOfThread() >= thread_cluster_desc_.GetElementSize(),
                      "wrong! ThreadGroup::GetNumOfThread() too small");

        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, typename DstBuffer, InMemoryDataOperationEnum DstInMemOp>
    __device__ void Run(const SrcDesc& src_desc,
                        const SrcBuffer& src_buf,
                        const DstDesc& dst_desc,
                        DstBuffer& dst_buf)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.template Run<SrcBuffer, DstBuffer, DstInMemOp>(
                src_desc, src_buf, dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_block_slice_origin)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
        }
    }

    __device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_block_slice_origin)
    {
        if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
           ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(ThreadGroup::GetThreadId()));

            const auto thread_data_idx_begin = thread_cluster_idx * thread_slice_lengths;

            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer =
        ThreadwiseTensorSliceTransfer_v6r1r2<SrcData,
                                             DstData,
                                             SrcDesc,
                                             DstDesc,
                                             ElementwiseOperation,
                                             decltype(thread_slice_lengths),
                                             DimAccessOrder,
                                             VectorDim,
                                             ScalarPerVector,
                                             ThreadTransferSrcResetCoordinateAfterRun,
                                             ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp

@@ -19,8 +19,7 @@ getConvBackwardDataSpecializationString(const ConvolutionBackwardDataSpecializat
     switch(s)
     {
     case ConvolutionBackwardDataSpecialization::Default: return "Default";
-    case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0:
-        return "FFilter1x1Stride1Pad0";
+    case ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
     default: return "Unrecognized specialization!";
     }
 }
...
include/ck/tensor_operation/gpu/device/device_gemm_streamk.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceGemmStreamK : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op,
                        ck::index_t NumSKBlocks = 0) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceGemmStreamKPtr = std::unique_ptr<DeviceGemmStreamK<ALayout,
                                                               BLayout,
                                                               CLayout,
                                                               ADataType,
                                                               BDataType,
                                                               CDataType,
                                                               AElementwiseOperation,
                                                               BElementwiseOperation,
                                                               CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
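DeviceGemmStreamK follows CK's usual type-erased operator pattern: client code holds a DeviceGemmStreamKPtr and builds arguments and invokers through virtuals, so concrete XDL instances stay behind one stable interface. A reduced standalone sketch of that pattern (stand-in base classes, int instead of ck::index_t, element-op parameters omitted; not part of the commit):

#include <memory>

struct BaseArgument { virtual ~BaseArgument() = default; };
struct BaseInvoker  { virtual ~BaseInvoker()  = default; };

// Reduced stand-in for the DeviceGemmStreamK interface above.
struct GemmStreamKOperator
{
    virtual ~GemmStreamKOperator() = default;

    // NumSKBlocks = 0 lets the implementation choose the stream-k split.
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                              const void* p_b,
                                                              void* p_c,
                                                              int M, int N, int K,
                                                              int StrideA, int StrideB, int StrideC,
                                                              int NumSKBlocks = 0) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

// Instance libraries hand these out; callers never see the concrete type.
using GemmStreamKOperatorPtr = std::unique_ptr<GemmStreamKOperator>;

int main() {} // interface only; no concrete instance in this sketch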
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp

@@ -27,17 +27,19 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
     MakeArgumentPointer(const void* p_in,
                         void* p_wei,
                         const void* p_out,
-                        ck::index_t G,
-                        ck::index_t N,
-                        ck::index_t K,
-                        ck::index_t C,
-                        std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
-                        std::array<ck::index_t, NDimSpatial> conv_filter_strides,
-                        std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
-                        std::array<ck::index_t, NDimSpatial> input_left_pads,
-                        std::array<ck::index_t, NDimSpatial> input_right_pads,
+                        const ck::index_t G,
+                        const ck::index_t N,
+                        const ck::index_t K,
+                        const ck::index_t C,
+                        const std::array<ck::index_t, NDimSpatial>& input_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial>& filter_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial>& output_spatial_lengths,
+                        const std::array<ck::index_t, NDimSpatial + 3>& input_strides,
+                        const std::array<ck::index_t, NDimSpatial + 3>& output_strides,
+                        const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
+                        const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
+                        const std::array<ck::index_t, NDimSpatial>& input_left_pads,
+                        const std::array<ck::index_t, NDimSpatial>& input_right_pads,
                         InElementwiseOperation in_element_op,
                         WeiElementwiseOperation wei_element_op,
                         OutElementwiseOperation out_element_op,
...
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_xdl_cshuffle.hpp

@@ -840,9 +840,7 @@ struct DeviceBatchedContractionMultipleD_Xdl_CShuffle
     static bool IsSupportedArgument(const Argument& arg)
     {
-        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
-             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
-             ck::get_device_name() == "gfx942"))
+        if(!ck::is_xdl_supported())
         {
             return false;
         }
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp

@@ -571,6 +571,11 @@ struct DeviceBatchedGemmEPermuteXdl : public DeviceBatchedGemmEPermute<ALayout,
     static bool IsSupportedArgument(const Argument& arg)
     {
+        if(!ck::is_xdl_supported())
+        {
+            return false;
+        }
+
         return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
                                            ck::Tuple<>{},
...