Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
b2290854
"mmdet3d/models/vscode:/vscode.git/clone" did not exist on "878c9ff84fc199e20f39530966e2871953b7bacf"
Commit
b2290854
authored
May 27, 2022
by
rocking
Browse files
Merge commit '
3e6c2610
' into gemm_norm
parents
253f7ef2
3e6c2610
Changes
201
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1138 additions
and
94 deletions
+1138
-94
example/13_pool2d_fwd/pool2d_fwd_fp32.cpp
example/13_pool2d_fwd/pool2d_fwd_fp32.cpp
+116
-0
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
...quant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
+7
-2
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
+1
-1
example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp
example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp
+2
-1
example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp
...e/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp
+2
-1
example/19_binary_elementwise/CMakeLists.txt
example/19_binary_elementwise/CMakeLists.txt
+2
-1
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
example/19_binary_elementwise/broadcast_add_2d_amn_bn.cpp
+13
-4
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
+123
-0
example/19_binary_elementwise/elementwise_add_1d.cpp
example/19_binary_elementwise/elementwise_add_1d.cpp
+13
-4
example/19_binary_elementwise/elementwise_add_4d.cpp
example/19_binary_elementwise/elementwise_add_4d.cpp
+13
-4
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
+47
-9
example/CMakeLists.txt
example/CMakeLists.txt
+1
-0
include/ck/host_utility/device_prop.hpp
include/ck/host_utility/device_prop.hpp
+50
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
.../ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
+7
-9
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
...ration/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
+3
-4
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+2
-0
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
...tensor_operation/gpu/device/device_binary_elementwise.hpp
+92
-51
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+51
-0
include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp
include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp
+586
-0
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
+7
-3
No files found.
example/13_pool2d_fwd/pool2d_fwd_fp32.cpp
0 → 100644
View file @
b2290854
#include <iostream>
#include <cstdlib>
#include "config.hpp"
#include "tensor_layout.hpp"
#include "reduction_enums.hpp"
#include "pool2d_fwd_common.hpp"
using
InDataType
=
float
;
using
OutDataType
=
float
;
using
AccDataType
=
float
;
using
IndexDataType
=
int32_t
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NHWC
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NHWC
;
#if 1
static
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
MAX
;
#else
static
constexpr
auto
ReduceOpId
=
ck
::
ReduceTensorOp
::
AVG
;
#endif
static
constexpr
bool
OutputIndex
=
false
;
static
constexpr
bool
PropagateNan
=
false
;
int
main
(
int
argc
,
char
*
argv
[])
{
using
namespace
ck
::
host_reduce
;
bool
do_verification
;
int
init_method
;
bool
time_kernel
;
// Pool shape
ck
::
index_t
N
=
128
;
ck
::
index_t
C
=
192
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
Hi
=
71
;
ck
::
index_t
Wi
=
71
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
in_left_pad_h
=
1
;
ck
::
index_t
in_left_pad_w
=
1
;
ck
::
index_t
in_right_pad_h
=
1
;
ck
::
index_t
in_right_pad_w
=
1
;
if
(
argc
==
1
)
{
do_verification
=
true
;
init_method
=
1
;
time_kernel
=
true
;
}
else
if
(
argc
==
4
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
static_cast
<
bool
>
(
std
::
stoi
(
argv
[
3
]));
}
else
if
(
argc
==
16
)
{
do_verification
=
std
::
stoi
(
argv
[
1
]);
init_method
=
std
::
stoi
(
argv
[
2
]);
time_kernel
=
static_cast
<
bool
>
(
std
::
stoi
(
argv
[
3
]));
N
=
std
::
stoi
(
argv
[
4
]);
C
=
std
::
stoi
(
argv
[
5
]);
Y
=
std
::
stoi
(
argv
[
6
]);
X
=
std
::
stoi
(
argv
[
7
]);
Hi
=
std
::
stoi
(
argv
[
8
]);
Wi
=
std
::
stoi
(
argv
[
9
]);
window_stride_h
=
std
::
stoi
(
argv
[
10
]);
window_stride_w
=
std
::
stoi
(
argv
[
11
]);
in_left_pad_h
=
std
::
stoi
(
argv
[
12
]);
in_left_pad_w
=
std
::
stoi
(
argv
[
13
]);
in_right_pad_h
=
std
::
stoi
(
argv
[
14
]);
in_right_pad_w
=
std
::
stoi
(
argv
[
15
]);
}
else
{
printf
(
"arg1: verification (0=no, 1=yes)
\n
"
);
printf
(
"arg2: initialization (0=no init, 1=integer value, 2=decimal value)
\n
"
);
printf
(
"arg3: time kernel (0=no, 1=yes)
\n
"
);
printf
(
"arg4 to 15: N, C, Y, X, Hi, Wi, Sy, Sx, LeftPy, LeftPx, RightPy, "
"RightPx
\n
"
);
exit
(
0
);
}
bool
pass
=
pool_test
<
InDataType
,
OutDataType
,
AccDataType
,
IndexDataType
,
InLayout
,
OutLayout
,
ReduceOpId
,
PropagateNan
,
OutputIndex
>
(
do_verification
,
init_method
,
time_kernel
,
N
,
C
,
Y
,
X
,
Hi
,
Wi
,
window_stride_h
,
window_stride_w
,
in_left_pad_h
,
in_left_pad_w
,
in_right_pad_h
,
in_right_pad_w
);
return
(
pass
?
0
:
1
);
}
example/14_gemm_xdl_requant_relu_requant/gemm_xdl_requant_relu_requant_int8.cpp
View file @
b2290854
...
...
@@ -100,8 +100,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm_Xdl_CShuffle
16
>
;
// index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
PassThrough
,
PassThrough
,
RequantReluRequant
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
float
,
PassThrough
,
PassThrough
,
RequantReluRequant
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
example/15_grouped_gemm/grouped_gemm_xdl_fp16.cpp
View file @
b2290854
...
...
@@ -56,7 +56,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGroupedGemmXdl
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
example/16_gemm_reduce/gemm_reduce_xdl_max_fp16.cpp
View file @
b2290854
...
...
@@ -32,6 +32,7 @@ using CDataType = F16;
using
ReduceAccDataType
=
F32
;
using
DDataType
=
F64
;
using
DPtrsGlobal
=
ck
::
Tuple
<
DDataType
*>
;
using
AccDataType
=
F32
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -59,7 +60,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
example/16_gemm_reduce/gemm_reduce_xdl_mean_squaremean_fp16.cpp
View file @
b2290854
...
...
@@ -32,6 +32,7 @@ using CDataType = F16;
using
ReduceAccDataType
=
F32
;
using
DDataType
=
F32
;
using
DPtrsGlobal
=
ck
::
Tuple
<
DDataType
*
,
DDataType
*>
;
using
AccDataType
=
F32
;
using
ALayout
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
BLayout
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
...
@@ -70,7 +71,7 @@ using DeviceGemmReduceInstance = ck::tensor_operation::device::DeviceGemmReduce_
// clang-format on
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementOp
,
BElementOp
,
CElementOp
>
;
int
main
(
int
argc
,
char
*
argv
[])
{
...
...
example/19_binary_elementwise/CMakeLists.txt
View file @
b2290854
add_example_executable
(
example_broadcast_add_2d broadcast_add_2d.cpp
)
add_example_executable
(
example_broadcast_add_2d_amn_bn broadcast_add_2d_amn_bn.cpp
)
add_example_executable
(
example_broadcast_add_3d_am_bmnk broadcast_add_3d_am_bmnk.cpp
)
add_example_executable
(
example_elementwise_add_1d elementwise_add_1d.cpp
)
add_example_executable
(
example_elementwise_add_4d elementwise_add_4d.cpp
)
\ No newline at end of file
example/19_binary_elementwise/broadcast_add_2d.cpp
→
example/19_binary_elementwise/broadcast_add_2d
_amn_bn
.cpp
View file @
b2290854
...
...
@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
2
,
8
>
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
2
,
8
,
8
,
8
,
8
>
;
template
<
typename
HostTensorA
,
typename
HostTensorB
,
...
...
@@ -100,7 +109,7 @@ int main()
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
"DeviceBinaryElementwise
_2D
instance, exiting!"
);
"DeviceBinaryElementwise instance, exiting!"
);
};
auto
broadcastAdd_invoker_ptr
=
broadcastAdd
.
MakeInvokerPointer
();
...
...
@@ -123,7 +132,7 @@ int main()
0
>
(
host_c_m_n
,
a_m_n
,
b_n
,
M
,
N
,
Add
{});
pass
&=
ck
::
utils
::
check_err
(
c_m_n
.
mData
,
host_c_m_n
.
mData
,
"Error: Incorrect results
d1
"
,
1e-3
,
1e-3
);
c_m_n
.
mData
,
host_c_m_n
.
mData
,
"Error: Incorrect results
c
"
,
1e-3
,
1e-3
);
}
return
pass
?
0
:
1
;
...
...
example/19_binary_elementwise/broadcast_add_3d_am_bmnk.cpp
0 → 100644
View file @
b2290854
#include <iostream>
#include <cstdlib>
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
ABDataType
=
F16
;
using
CDataType
=
F16
;
using
EltwiseComputeDataType
=
F32
;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
3
,
8
,
1
,
8
,
8
>
;
template
<
typename
HostTensorA
,
typename
HostTensorB
,
typename
HostTensorC
,
typename
ComputeDataType
,
typename
Functor
>
void
host_broadcast3D_am_bmnk
(
HostTensorC
&
C
,
const
HostTensorA
&
A
,
const
HostTensorB
&
B
,
const
std
::
vector
<
std
::
size_t
>&
shape
,
Functor
functor
)
{
using
ctype
=
ck
::
remove_reference_t
<
decltype
(
C
(
0
,
0
))
>
;
for
(
std
::
size_t
m
=
0
;
m
<
shape
[
0
];
++
m
)
for
(
std
::
size_t
n
=
0
;
n
<
shape
[
1
];
++
n
)
for
(
std
::
size_t
k
=
0
;
k
<
shape
[
2
];
++
k
)
{
ComputeDataType
a_val
=
static_cast
<
ComputeDataType
>
(
A
(
m
));
ComputeDataType
b_val
=
static_cast
<
ComputeDataType
>
(
B
(
m
,
n
,
k
));
ComputeDataType
c_val
=
0
;
functor
(
c_val
,
a_val
,
b_val
);
C
(
m
,
n
,
k
)
=
static_cast
<
ctype
>
(
c_val
);
}
}
int
main
()
{
bool
do_verification
=
true
;
bool
time_kernel
=
false
;
std
::
vector
<
std
::
size_t
>
mnk
=
{
4
,
16
,
32
};
ck
::
index_t
M
=
mnk
[
0
];
Tensor
<
ABDataType
>
a_m
({
M
});
Tensor
<
ABDataType
>
b_m_n_k
(
mnk
);
Tensor
<
CDataType
>
c_m_n_k
(
mnk
);
a_m
.
GenerateTensorValue
(
GeneratorTensor_3
<
ABDataType
>
{
0.0
,
1.0
});
b_m_n_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ABDataType
>
{
0.0
,
1.0
});
DeviceMem
a_m_device_buf
(
sizeof
(
ABDataType
)
*
a_m
.
mDesc
.
GetElementSpace
());
DeviceMem
b_m_n_k_device_buf
(
sizeof
(
ABDataType
)
*
b_m_n_k
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_k_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_k
.
mDesc
.
GetElementSpace
());
a_m_device_buf
.
ToDevice
(
a_m
.
mData
.
data
());
b_m_n_k_device_buf
.
ToDevice
(
b_m_n_k
.
mData
.
data
());
auto
broadcastAdd
=
DeviceElementwiseAddInstance
{};
auto
argument
=
broadcastAdd
.
MakeArgumentPointer
(
a_m_device_buf
.
GetDeviceBuffer
(),
b_m_n_k_device_buf
.
GetDeviceBuffer
(),
c_m_n_k_device_buf
.
GetDeviceBuffer
(),
std
::
vector
<
ck
::
index_t
>
{
mnk
.
begin
(),
mnk
.
end
()},
{
1
,
0
,
0
},
// broadcast A on second and third dimension
std
::
vector
<
ck
::
index_t
>
{
b_m_n_k
.
mDesc
.
GetStrides
().
begin
(),
b_m_n_k
.
mDesc
.
GetStrides
().
end
()},
std
::
vector
<
ck
::
index_t
>
{
c_m_n_k
.
mDesc
.
GetStrides
().
begin
(),
c_m_n_k
.
mDesc
.
GetStrides
().
end
()},
Add
{});
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
"DeviceBinaryElementwise instance, exiting!"
);
};
auto
broadcastAdd_invoker_ptr
=
broadcastAdd
.
MakeInvokerPointer
();
float
ave_time
=
broadcastAdd_invoker_ptr
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
cout
<<
"Perf: "
<<
ave_time
<<
" ms"
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
{
c_m_n_k_device_buf
.
FromDevice
(
c_m_n_k
.
mData
.
data
());
Tensor
<
CDataType
>
host_c_m_n_k
(
mnk
);
host_broadcast3D_am_bmnk
<
Tensor
<
ABDataType
>
,
Tensor
<
ABDataType
>
,
Tensor
<
CDataType
>
,
EltwiseComputeDataType
,
Add
>
(
host_c_m_n_k
,
a_m
,
b_m_n_k
,
mnk
,
Add
{});
pass
&=
ck
::
utils
::
check_err
(
c_m_n_k
.
mData
,
host_c_m_n_k
.
mData
,
"Error: Incorrect results c"
,
1e-3
,
1e-3
);
}
return
pass
?
0
:
1
;
}
example/19_binary_elementwise/elementwise_add_1d.cpp
View file @
b2290854
...
...
@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
1
,
8
>
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
1
,
8
,
8
,
8
,
8
>
;
template
<
typename
HostTensorA
,
typename
HostTensorB
,
...
...
@@ -81,7 +90,7 @@ int main()
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
"DeviceBinaryElementwise
_2D
instance, exiting!"
);
"DeviceBinaryElementwise instance, exiting!"
);
};
auto
broadcastAdd_invoker_ptr
=
broadcastAdd
.
MakeInvokerPointer
();
...
...
@@ -103,7 +112,7 @@ int main()
Add
>
(
host_c_m
,
a_m
,
b_m
,
M
,
Add
{});
pass
&=
ck
::
utils
::
check_err
(
c_m
.
mData
,
host_c_m
.
mData
,
"Error: Incorrect results
d1
"
,
1e-3
,
1e-3
);
c_m
.
mData
,
host_c_m
.
mData
,
"Error: Incorrect results
c
"
,
1e-3
,
1e-3
);
}
return
pass
?
0
:
1
;
...
...
example/19_binary_elementwise/elementwise_add_4d.cpp
View file @
b2290854
...
...
@@ -19,8 +19,17 @@ using EltwiseComputeDataType = F32;
using
Add
=
ck
::
tensor_operation
::
binary_element_wise
::
Add
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
4
,
8
>
;
using
DeviceElementwiseAddInstance
=
ck
::
tensor_operation
::
device
::
DeviceBinaryElementwise
<
ABDataType
,
ABDataType
,
CDataType
,
EltwiseComputeDataType
,
Add
,
4
,
8
,
8
,
8
,
8
>
;
template
<
typename
HostTensorA
,
typename
HostTensorB
,
...
...
@@ -83,7 +92,7 @@ int main()
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
throw
std
::
runtime_error
(
"The runtime parameters seems not supported by the "
"DeviceBinaryElementwise
_2D
instance, exiting!"
);
"DeviceBinaryElementwise instance, exiting!"
);
};
auto
broadcastAdd_invoker_ptr
=
broadcastAdd
.
MakeInvokerPointer
();
...
...
@@ -105,7 +114,7 @@ int main()
Add
>
(
host_c
,
a
,
b
,
nchw
,
Add
{});
pass
&=
ck
::
utils
::
check_err
(
c
.
mData
,
host_c
.
mData
,
"Error: Incorrect results
d1
"
,
1e-3
,
1e-3
);
ck
::
utils
::
check_err
(
c
.
mData
,
host_c
.
mData
,
"Error: Incorrect results
c
"
,
1e-3
,
1e-3
);
}
return
pass
?
0
:
1
;
...
...
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
View file @
b2290854
...
...
@@ -257,11 +257,11 @@ int main(int argc, char* argv[])
case
0
:
break
;
case
1
:
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
2
,
2
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
Wei
DataType
>
{
-
2
,
2
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
In
DataType
>
{
-
2
,
2
});
break
;
default:
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_1
<
OutDataType
>
{
1
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_1
<
Wei
DataType
>
{
1
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_1
<
In
DataType
>
{
1
});
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
...
...
@@ -296,15 +296,53 @@ int main(int argc, char* argv[])
OutElementOp
{},
split_k
);
if
(
!
conv
->
IsSupportedArgument
(
argument
.
get
()))
// alloc work space
size_t
bwd_weight_workspace_size
=
conv
->
GetWorkSpaceSize
(
argument
.
get
());
float
ave_time
=
0.
f
;
if
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
&&
split_k
>
1
)
{
std
::
cout
<<
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<<
std
::
endl
;
return
1
;
}
DeviceMem
wei_work_space_device_buf
(
bwd_weight_workspace_size
);
wei_work_space_device_buf
.
SetZero
();
argument
=
conv
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
AccDataType
*>
(
wei_work_space_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
params
.
N_
,
params
.
K_
,
params
.
C_
,
params
.
input_spatial_lengths_
,
params
.
filter_spatial_lengths_
,
output_spatial_lengths
,
params
.
conv_filter_strides_
,
params
.
conv_filter_dilations_
,
params
.
input_left_pads_
,
params
.
input_right_pads_
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{},
split_k
);
if
(
!
conv
->
IsSupportedArgument
(
argument
.
get
()))
{
std
::
cout
<<
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<<
std
::
endl
;
return
1
;
}
float
ave_time
=
invoker
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
ave_time
=
invoker
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
}
else
{
if
(
!
conv
->
IsSupportedArgument
(
argument
.
get
()))
{
std
::
cout
<<
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<<
std
::
endl
;
return
1
;
}
ave_time
=
invoker
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
}
std
::
size_t
flop
=
ck
::
utils
::
conv
::
get_flops
(
params
.
N_
,
params
.
C_
,
params
.
K_
,
params
.
filter_spatial_lengths_
,
output_spatial_lengths
);
...
...
example/CMakeLists.txt
View file @
b2290854
include_directories
(
BEFORE
${
PROJECT_SOURCE_DIR
}
/include/ck
${
PROJECT_SOURCE_DIR
}
/include/ck/utility
${
PROJECT_SOURCE_DIR
}
/include/ck/host_utility
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor_description
${
PROJECT_SOURCE_DIR
}
/include/ck/tensor
${
PROJECT_SOURCE_DIR
}
/include/ck/problem_transform
...
...
include/ck/host_utility/device_prop.hpp
0 → 100644
View file @
b2290854
#pragma once
#include <string>
#include <map>
namespace
ck
{
inline
std
::
string
get_device_name
()
{
hipDeviceProp_t
props
{};
int
device
;
auto
status
=
hipGetDevice
(
&
device
);
if
(
status
!=
hipSuccess
)
{
return
std
::
string
();
}
status
=
hipGetDeviceProperties
(
&
props
,
device
);
if
(
status
!=
hipSuccess
)
{
return
std
::
string
();
}
const
std
::
string
raw_name
(
props
.
gcnArchName
);
// https://github.com/ROCmSoftwarePlatform/MIOpen/blob/8498875aef84878e04c1eabefdf6571514891086/src/target_properties.cpp#L40
static
std
::
map
<
std
::
string
,
std
::
string
>
device_name_map
=
{
{
"Ellesmere"
,
"gfx803"
},
{
"Baffin"
,
"gfx803"
},
{
"RacerX"
,
"gfx803"
},
{
"Polaris10"
,
"gfx803"
},
{
"Polaris11"
,
"gfx803"
},
{
"Tonga"
,
"gfx803"
},
{
"Fiji"
,
"gfx803"
},
{
"gfx800"
,
"gfx803"
},
{
"gfx802"
,
"gfx803"
},
{
"gfx804"
,
"gfx803"
},
{
"Vega10"
,
"gfx900"
},
{
"gfx901"
,
"gfx900"
},
{
"10.3.0 Sienna_Cichlid 18"
,
"gfx1030"
},
};
const
auto
name
=
raw_name
.
substr
(
0
,
raw_name
.
find
(
':'
));
// str.substr(0, npos) returns str.
auto
match
=
device_name_map
.
find
(
name
);
if
(
match
!=
device_name_map
.
end
())
return
match
->
second
;
return
name
;
}
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl
ops
_v2r3.hpp
→
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
View file @
b2290854
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
#pragma once
#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer_v
2
.hpp"
#include "threadwise_contraction_dl
ops
.hpp"
#include "threadwise_tensor_slice_transfer_v
4r1
.hpp"
#include "threadwise_contraction_dl.hpp"
namespace
ck
{
...
...
@@ -41,7 +39,7 @@ template <index_t BlockSize,
typename
enable_if
<
ABlockDesc_BK0_BM_BK1
::
IsKnownAtCompileTime
()
&&
BBlockDesc_BK0_BN_BK1
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
BlockwiseGemmDl
ops
_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
struct
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{
using
AIndex
=
MultiIndex
<
3
>
;
using
BIndex
=
MultiIndex
<
3
>
;
...
...
@@ -148,7 +146,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
MakeBBlockDescriptor_BK0_BN0_BN1_BK1
(
BBlockDesc_BK0_BN_BK1
{});
public:
__device__
BlockwiseGemmDl
ops
_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
()
__device__
BlockwiseGemmDl_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
()
:
c_thread_origin_data_idx_
{
CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1
(
get_thread_local_1d_id
())},
a_thread_copy_
{
...
...
@@ -175,6 +173,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
"wrong!"
);
// TODO: remove this restriction
static_assert
(
BM0
==
2
,
"wrong"
);
static_assert
(
BM0
==
2
&&
BN0
==
2
,
"wrong"
);
}
...
...
@@ -226,7 +225,7 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
b_thread_desc_bk0_bn0_bn1_bk1_
.
GetElementSpaceSize
());
constexpr
auto
threadwise_contraction
=
ThreadwiseContractionDl
ops
_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
<
ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
<
FloatA
,
FloatB
,
FloatC
,
...
...
@@ -407,4 +406,3 @@ struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_B
};
}
// namespace ck
#endif
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
View file @
b2290854
...
...
@@ -75,14 +75,13 @@ struct BlockwiseTensorSliceTransfer_v5r1
}
}
template
<
typename
SrcBuffer
,
typename
SrcStepHacks
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
,
const
SrcStepHacks
&
src_step_hacks
)
template
<
typename
SrcBuffer
>
__device__
void
RunRead
(
const
SrcDesc
&
src_desc
,
const
SrcBuffer
&
src_buf
)
{
if
(
BlockSize
==
thread_cluster_desc_
.
GetElementSize
()
or
get_thread_local_1d_id
()
<
thread_cluster_desc_
.
GetElementSize
())
{
threadwise_transfer_
.
RunRead
(
src_desc
,
src_buf
,
src_step_hacks
);
threadwise_transfer_
.
RunRead
(
src_desc
,
src_buf
);
}
}
...
...
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
b2290854
...
...
@@ -40,6 +40,8 @@ struct BaseOperator
virtual
bool
IsSupportedArgument
(
const
BaseArgument
*
)
{
return
false
;
}
virtual
std
::
string
GetTypeString
()
const
{
return
""
;
}
virtual
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
)
const
{
return
0
;
}
virtual
~
BaseOperator
()
{}
};
...
...
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
View file @
b2290854
...
...
@@ -15,91 +15,107 @@ template <typename ADataType,
typename
CDataType
,
typename
ComputeDataType
,
typename
ElementwiseFunctor
,
index_t
Dim
,
index_t
ScalarPerVector
>
index_t
NDim
,
index_t
MPerThread
,
index_t
AScalarPerVector
,
index_t
BScalarPerVector
,
index_t
CScalarPerVector
>
struct
DeviceBinaryElementwise
:
public
BaseOperator
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
template
<
typename
Desc_M
0
>
static
auto
PadDescriptor_M
0
_1d
(
Desc_M
0
desc_m
0
,
index_t
gridSize
,
index_t
blockSize
)
template
<
typename
Desc_M
>
static
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
,
index_t
gridSize
,
index_t
blockSize
)
{
const
auto
m0
=
desc_m
0
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
ScalarPerVector
;
const
auto
pad
=
math
::
integer_least_multiple
(
m0
,
loop_step
)
-
m0
;
const
auto
desc_m
0
_pad
=
transform_tensor_descriptor
(
desc_m
0
,
make_tuple
(
make_right_pad_transform
(
m0
,
pad
)),
const
auto
M
=
desc_m
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
MPerThread
;
const
auto
pad
=
math
::
integer_least_multiple
(
M
,
loop_step
)
-
M
;
const
auto
desc_m_pad
=
transform_tensor_descriptor
(
desc_m
,
make_tuple
(
make_right_pad_transform
(
M
,
pad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
desc_m
0
_pad
;
return
desc_m_pad
;
}
static
auto
MakeDescriptor_M
0
(
const
std
::
vector
<
index_t
>&
shape
,
const
std
::
vector
<
index_t
>&
stride
,
index_t
gridSize
,
index_t
blockSize
)
static
auto
MakeDescriptor_M
(
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
stride
s
,
index_t
gridSize
,
index_t
blockSize
)
{
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
shape
[
I
];
},
Number
<
Dim
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
[
I
];
},
Number
<
Dim
>
{});
auto
tupleOfShape
=
generate_tuple
([
&
](
auto
I
)
{
return
lengths
[
I
];
},
Number
<
N
Dim
>
{});
auto
tupleOfStride
=
generate_tuple
([
&
](
auto
I
)
{
return
stride
s
[
I
];
},
Number
<
N
Dim
>
{});
// nd desc - [s0, s1, s2, ...]
const
auto
desc
=
make_naive_tensor_descriptor
(
tupleOfShape
,
tupleOfStride
);
// merge nd to 1d desc - [s0 * s1 * ...]
if
constexpr
(
Dim
>
1
)
if
constexpr
(
N
Dim
>
1
)
{
const
auto
desc_m
0
=
transform_tensor_descriptor
(
const
auto
desc_m
=
transform_tensor_descriptor
(
desc
,
make_tuple
(
make_merge_transform
(
tupleOfShape
)),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
Dim
>
{})),
make_tuple
(
generate_sequence_v2
([
&
](
auto
I
)
{
return
I
;
},
Number
<
N
Dim
>
{})),
make_tuple
(
Sequence
<
0
>
{}));
return
PadDescriptor_M
0
_1d
(
desc_m
0
,
gridSize
,
blockSize
);
return
PadDescriptor_M_1d
(
desc_m
,
gridSize
,
blockSize
);
}
else
return
PadDescriptor_M
0
_1d
(
desc
,
gridSize
,
blockSize
);
return
PadDescriptor_M_1d
(
desc
,
gridSize
,
blockSize
);
}
using
GridDesc_M0
=
decltype
(
MakeDescriptor_M0
({
1
,
1
},
{
1
,
1
},
1
,
1
));
using
AGridDesc_M
=
decltype
(
MakeDescriptor_M
({
1
,
1
},
{
1
,
1
},
1
,
1
));
using
BGridDesc_M
=
decltype
(
MakeDescriptor_M
({
1
,
1
},
{
1
,
1
},
1
,
1
));
using
CGridDesc_M
=
decltype
(
MakeDescriptor_M
({
1
,
1
},
{
1
,
1
},
1
,
1
));
using
GridwiseBinEltwise
=
GridwiseBinaryElementwise_1D
<
ADataType
,
BDataType
,
CDataType
,
ComputeDataType
,
GridDesc_M0
,
AGridDesc_M
,
BGridDesc_M
,
CGridDesc_M
,
ElementwiseFunctor
,
ScalarPerVector
>
;
MPerThread
,
AScalarPerVector
,
BScalarPerVector
,
CScalarPerVector
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
const
std
::
vector
<
index_t
>&
shape
,
const
std
::
vector
<
index_t
>&
stride
_a
,
const
std
::
vector
<
index_t
>&
stride
_b
,
const
std
::
vector
<
index_t
>&
stride
_c
,
const
std
::
vector
<
index_t
>&
lengths
,
const
std
::
vector
<
index_t
>&
a_
stride
s
,
const
std
::
vector
<
index_t
>&
b_
stride
s
,
const
std
::
vector
<
index_t
>&
c_
stride
s
,
ElementwiseFunctor
functor
)
:
p_a_
(
p_a
),
p_b_
(
p_b
),
p_c_
(
p_c
),
shape_
(
shape
),
lengths_
(
lengths
),
a_strides_
(
a_strides
),
b_strides_
(
b_strides
),
c_strides_
(
c_strides
),
functor_
(
functor
),
blockSize_
(
256
),
gridSize_
(
120
)
// FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m
0
_
=
MakeDescriptor_M
0
(
shape
,
stride
_a
,
gridSize_
,
blockSize_
);
b_grid_desc_m
0
_
=
MakeDescriptor_M
0
(
shape
,
stride
_b
,
gridSize_
,
blockSize_
);
c_grid_desc_m
0
_
=
MakeDescriptor_M
0
(
shape
,
stride
_c
,
gridSize_
,
blockSize_
);
a_grid_desc_m_
=
MakeDescriptor_M
(
lengths
,
a_
stride
s
,
gridSize_
,
blockSize_
);
b_grid_desc_m_
=
MakeDescriptor_M
(
lengths
,
b_
stride
s
,
gridSize_
,
blockSize_
);
c_grid_desc_m_
=
MakeDescriptor_M
(
lengths
,
c_
stride
s
,
gridSize_
,
blockSize_
);
}
const
ADataType
*
p_a_
;
const
BDataType
*
p_b_
;
CDataType
*
p_c_
;
std
::
vector
<
int
>
shape_
;
GridDesc_M0
a_grid_desc_m0_
;
GridDesc_M0
b_grid_desc_m0_
;
GridDesc_M0
c_grid_desc_m0_
;
std
::
vector
<
int
>
lengths_
;
AGridDesc_M
a_grid_desc_m_
;
BGridDesc_M
b_grid_desc_m_
;
CGridDesc_M
c_grid_desc_m_
;
std
::
vector
<
index_t
>
a_strides_
;
std
::
vector
<
index_t
>
b_strides_
;
std
::
vector
<
index_t
>
c_strides_
;
ElementwiseFunctor
functor_
;
index_t
blockSize_
;
index_t
gridSize_
;
...
...
@@ -113,7 +129,9 @@ struct DeviceBinaryElementwise : public BaseOperator
ADataType
,
BDataType
,
CDataType
,
GridDesc_M0
,
AGridDesc_M
,
BGridDesc_M
,
CGridDesc_M
,
ElementwiseFunctor
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
...
...
@@ -124,9 +142,9 @@ struct DeviceBinaryElementwise : public BaseOperator
arg
.
p_a_
,
arg
.
p_b_
,
arg
.
p_c_
,
arg
.
a_grid_desc_m
0
_
,
arg
.
b_grid_desc_m
0
_
,
arg
.
c_grid_desc_m
0
_
,
arg
.
a_grid_desc_m_
,
arg
.
b_grid_desc_m_
,
arg
.
c_grid_desc_m_
,
arg
.
functor_
);
return
elapsed_time
;
}
...
...
@@ -146,7 +164,30 @@ struct DeviceBinaryElementwise : public BaseOperator
if
(
pArg
==
nullptr
)
return
false
;
if
(
pArg
->
shape_
.
back
()
%
ScalarPerVector
!=
0
)
if
(
pArg
->
lengths_
.
size
()
!=
NDim
)
return
false
;
if
(
pArg
->
lengths_
.
back
()
%
MPerThread
!=
0
)
return
false
;
auto
IsScalarPerVectorValid
=
[](
bool
isLastDimensionCoalesced
,
int
scalarPerVector
)
{
bool
ret
=
true
;
if
(
!
isLastDimensionCoalesced
)
ret
=
scalarPerVector
==
1
;
else
ret
=
MPerThread
%
scalarPerVector
==
0
;
return
ret
;
};
if
(
!
IsScalarPerVectorValid
(
pArg
->
a_strides_
.
back
()
==
1
,
AScalarPerVector
))
return
false
;
if
(
!
IsScalarPerVectorValid
(
pArg
->
b_strides_
.
back
()
==
1
,
BScalarPerVector
))
return
false
;
if
(
!
IsScalarPerVectorValid
(
pArg
->
c_strides_
.
back
()
==
1
,
CScalarPerVector
))
return
false
;
return
true
;
...
...
@@ -155,19 +196,19 @@ struct DeviceBinaryElementwise : public BaseOperator
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
std
::
vector
<
index_t
>
shape
,
std
::
vector
<
index_t
>
stride
_a
,
std
::
vector
<
index_t
>
stride
_b
,
std
::
vector
<
index_t
>
stride
_c
,
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
index_t
>
a_
stride
s
,
std
::
vector
<
index_t
>
b_
stride
s
,
std
::
vector
<
index_t
>
c_
stride
s
,
ElementwiseFunctor
functor
)
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
shape
,
stride
_a
,
stride
_b
,
stride
_c
,
lengths
,
a_
stride
s
,
b_
stride
s
,
c_
stride
s
,
functor
);
}
...
...
@@ -180,7 +221,7 @@ struct DeviceBinaryElementwise : public BaseOperator
// clang-format off
str
<<
"DeviceBinaryElementwise"
<<
"<"
<<
"
ScalarPerVector = "
<<
ScalarPerVector
<<
"
MPerThread = "
<<
MPerThread
<<
">"
;
// clang-format on
...
...
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
b2290854
...
...
@@ -1175,6 +1175,57 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return
str
.
str
();
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
size_t
GetWorkSpaceSize
(
const
Argument
&
arg
)
{
size_t
WorkSpaceSize
=
0
;
if
(
arg
.
k_batch_
>
1
)
{
if
constexpr
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
)
{
WorkSpaceSize
=
arg
.
Conv_K_
*
arg
.
Conv_C_
*
arg
.
filter_spatial_lengths_
[
0
]
*
sizeof
(
float
);
}
}
return
WorkSpaceSize
;
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
size_t
GetWorkSpaceSize
(
const
Argument
&
arg
)
{
size_t
WorkSpaceSize
=
0
;
if
(
arg
.
k_batch_
>
1
)
{
if
constexpr
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
)
{
WorkSpaceSize
=
arg
.
Conv_K_
*
arg
.
Conv_C_
*
arg
.
filter_spatial_lengths_
[
0
]
*
arg
.
filter_spatial_lengths_
[
1
]
*
sizeof
(
float
);
}
}
return
WorkSpaceSize
;
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
size_t
GetWorkSpaceSize
(
const
Argument
&
arg
)
{
size_t
WorkSpaceSize
=
0
;
if
(
arg
.
k_batch_
>
1
)
{
if
constexpr
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
)
{
WorkSpaceSize
=
arg
.
Conv_K_
*
arg
.
Conv_C_
*
arg
.
filter_spatial_lengths_
[
0
]
*
arg
.
filter_spatial_lengths_
[
1
]
*
arg
.
filter_spatial_lengths_
[
2
]
*
sizeof
(
float
);
}
}
return
WorkSpaceSize
;
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
final
{
return
GetWorkSpaceSize
<
NumDimSpatial
>
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
};
}
// namespace device
...
...
include/ck/tensor_operation/gpu/device/device_gemm_dl.hpp
0 → 100644
View file @
b2290854
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_gemm_xdl.hpp
View file @
b2290854
#ifndef DEVICE_GEMM_XDL_HPP
#define DEVICE_GEMM_XDL_HPP
#pragma once
#include <iostream>
#include <sstream>
...
...
@@ -12,6 +11,7 @@
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r3.hpp"
#include "gemm_specialization.hpp"
#include "device_prop.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -408,6 +408,11 @@ struct DeviceGemmXdl
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m_n_
,
...
...
@@ -515,4 +520,3 @@ struct DeviceGemmXdl
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
Prev
1
2
3
4
5
6
…
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment