Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ac76519a
Unverified
Commit
ac76519a
authored
Aug 10, 2023
by
Adam Osewski
Committed by
GitHub
Aug 10, 2023
Browse files
Merge branch 'develop' into aosewski/gemm_tile_loop
parents
a70c6283
578142db
Changes
174
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1336 additions
and
308 deletions
+1336
-308
example/42_groupnorm/CMakeLists.txt
example/42_groupnorm/CMakeLists.txt
+5
-3
example/43_splitk_gemm_bias_e_permute/CMakeLists.txt
example/43_splitk_gemm_bias_e_permute/CMakeLists.txt
+6
-2
example/44_elementwise_permute/CMakeLists.txt
example/44_elementwise_permute/CMakeLists.txt
+4
-2
example/46_gemm_add_multiply/CMakeLists.txt
example/46_gemm_add_multiply/CMakeLists.txt
+6
-2
example/48_pool3d_fwd/CMakeLists.txt
example/48_pool3d_fwd/CMakeLists.txt
+3
-2
example/49_maxpool2d_bwd/CMakeLists.txt
example/49_maxpool2d_bwd/CMakeLists.txt
+9
-3
example/50_put_element/CMakeLists.txt
example/50_put_element/CMakeLists.txt
+3
-1
example/51_avgpool3d_bwd/CMakeLists.txt
example/51_avgpool3d_bwd/CMakeLists.txt
+3
-0
example/51_avgpool3d_bwd/avgpool3d_bwd_bf16.cpp
example/51_avgpool3d_bwd/avgpool3d_bwd_bf16.cpp
+62
-0
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp
+147
-0
example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp
example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp
+62
-0
example/51_avgpool3d_bwd/avgpool3d_bwd_fp32.cpp
example/51_avgpool3d_bwd/avgpool3d_bwd_fp32.cpp
+62
-0
include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp
...ude/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp
+39
-0
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
...r_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
+6
-9
include/ck/tensor_operation/gpu/device/device_put_element.hpp
...ude/ck/tensor_operation/gpu/device/device_put_element.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp
...tion/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp
+575
-0
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
...ation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
+12
-5
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+6
-3
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
...impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
+78
-81
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
...vice/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
+247
-194
No files found.
example/42_groupnorm/CMakeLists.txt
View file @
ac76519a
add_example_executable
(
example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp
)
add_example_executable
(
example_groupnorm_sigmoid_mul_fp16 groupnorm_sigmoid_mul_fp16.cpp
)
add_example_executable
(
example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp
)
add_example_executable
(
example_groupnorm_splitk_fp16 groupnorm_splitk_fp16.cpp
)
add_example_executable
(
example_groupnorm_swish_fp16 groupnorm_swish_fp16.cpp
)
endif
()
example/43_splitk_gemm_bias_e_permute/CMakeLists.txt
View file @
ac76519a
add_example_executable
(
example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp
)
add_example_executable
(
example_splitk_gemm_bias_e_permute_xdl_fp16 splitk_gemm_bias_e_permute_xdl_fp16.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_splitk_gemm_bias_e_permute_xdl_fp32 splitk_gemm_bias_e_permute_xdl_fp32.cpp
)
endif
()
example/44_elementwise_permute/CMakeLists.txt
View file @
ac76519a
add_example_executable
(
example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp
)
add_example_executable
(
example_elementwise_permute_4D_fp16 elementwise_permute_4D_fp16.cpp
)
add_example_executable
(
example_elementwise_permute_4D_fp16_2d elementwise_permute_4D_fp16_2d.cpp
)
endif
()
example/46_gemm_add_multiply/CMakeLists.txt
View file @
ac76519a
add_example_executable
(
example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp
)
if
(
DL_KERNELS
)
add_example_executable
(
example_gemm_add_multiply_dl_fp16 gemm_add_multiply_dl_fp16.cpp
)
endif
()
add_example_executable
(
example_gemm_add_multiply_xdl_fp16 gemm_add_multiply_xdl_fp16.cpp
)
endif
()
example/48_pool3d_fwd/CMakeLists.txt
View file @
ac76519a
add_example_executable
(
example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_pool3d_fwd_fp16 pool3d_fwd_fp16.cpp
)
endif
()
example/49_maxpool2d_bwd/CMakeLists.txt
View file @
ac76519a
add_example_executable
(
example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp
)
if
(
DTYPES MATCHES
"bf16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp
)
add_example_executable
(
example_maxpool2d_bwd_bf16 maxpool2d_bwd_bf16.cpp
)
add_example_executable
(
example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_maxpool2d_bwd_fp16 maxpool2d_bwd_fp16.cpp
)
endif
()
if
(
DTYPES MATCHES
"fp32"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp
)
endif
()
example/50_put_element/CMakeLists.txt
View file @
ac76519a
add_example_executable
(
example_put_element_fp16 put_element_fp16.cpp
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
add_example_executable
(
example_put_element_fp16 put_element_fp16.cpp
)
endif
()
example/51_avgpool3d_bwd/CMakeLists.txt
0 → 100644
View file @
ac76519a
add_example_executable
(
example_avgpool3d_bwd_bf16 avgpool3d_bwd_bf16.cpp
)
add_example_executable
(
example_avgpool3d_bwd_fp16 avgpool3d_bwd_fp16.cpp
)
add_example_executable
(
example_avgpool3d_bwd_fp32 avgpool3d_bwd_fp32.cpp
)
example/51_avgpool3d_bwd/avgpool3d_bwd_bf16.cpp
0 → 100644
View file @
ac76519a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp"
#include "avgpool3d_bwd_common.hpp"
// Element types used by this example: DOutDataType / DInDataType are the element types
// of the `dout` and `din` tensors built in pool3d_bwd_test; ComputeDataType is float.
using DOutDataType    = ck::bhalf_t;
using DInDataType     = ck::bhalf_t;
using ComputeDataType = float;

// Tensor layout selection. The `#if 1` branch (NDHWC for both tensors) is the one that
// is compiled; the NCDHW alternative in the `#else` branch is currently disabled.
#if 1
using DOutLayout = ck::tensor_layout::convolution::NDHWC;
using DInLayout  = ck::tensor_layout::convolution::NDHWC;
#else
using DOutLayout = ck::tensor_layout::convolution::NCDHW;
using DInLayout  = ck::tensor_layout::convolution::NCDHW;
#endif

// Device operation instantiated by this example. Each trailing comment names the
// template argument on the same line.
using DevicePoolBwdInstance =
    ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC<DOutDataType,
                                                                 DInDataType,
                                                                 ComputeDataType,
                                                                 64, // BlockSize
                                                                 64, // ReduceMThreadClusterSize
                                                                 1,  // ReduceKThreadClusterSize
                                                                 1,  // ReduceMThreadSliceSize
                                                                 1,  // ReduceKThreadSliceSize
                                                                 1>; // InSrcOutDstVectorSize
int
main
()
{
std
::
vector
<
ck
::
index_t
>
window_lengths
=
{
5
,
5
,
5
};
std
::
vector
<
ck
::
index_t
>
window_strides
=
{
2
,
2
,
2
};
std
::
vector
<
ck
::
index_t
>
window_dilations
=
{
2
,
2
,
2
};
std
::
vector
<
ck
::
index_t
>
dinput_left_pads
=
{
0
,
0
,
0
};
std
::
vector
<
ck
::
index_t
>
dinput_right_pads
=
{
0
,
0
,
0
};
ck
::
index_t
N
=
1
;
ck
::
index_t
C
=
16
;
ck
::
index_t
Di
=
40
;
ck
::
index_t
Hi
=
40
;
ck
::
index_t
Wi
=
40
;
pool3d_bwd_test
<
DevicePoolBwdInstance
,
DOutDataType
,
DInDataType
,
DOutLayout
,
DInLayout
>
(
true
,
false
,
N
,
C
,
Di
,
Hi
,
Wi
,
window_lengths
,
window_strides
,
window_dilations
,
dinput_left_pads
,
dinput_right_pads
);
}
example/51_avgpool3d_bwd/avgpool3d_bwd_common.hpp
0 → 100644
View file @
ac76519a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_avgpool_bwd.hpp"
/// Builds the five strides of a tensor whose dimensions are listed in
/// (N, C, D, H, W) order, for the memory layout selected by `layout`.
///
/// The returned vector always follows the (N, C, D, H, W) dimension order; only the
/// stride values depend on the layout:
///   - NCDHW: {C*D*H*W, D*H*W, H*W, W, 1}
///   - NDHWC: {D*C*H*W, 1, C*H*W, W*C, C}
///
/// @param N_     batch length (not needed to compute any stride)
/// @param C_     channel length
/// @param D,H,W  spatial lengths
/// @param layout tag object; only its type is inspected
///
/// Any other layout type is rejected at compile time. Previously such an instantiation
/// would flow off the end of this non-void function without returning a value.
template <typename TensorLayout>
std::vector<ck::index_t> f_tensor_strides_ncdhw(ck::index_t N_,
                                                ck::index_t C_,
                                                ck::index_t D,
                                                ck::index_t H,
                                                ck::index_t W,
                                                TensorLayout layout)
{
    using namespace ck::literals;

    // Neither value participates in the computation; silence unused-parameter warnings.
    (void)N_;
    (void)layout;

    if constexpr(ck::is_same<TensorLayout, ck::tensor_layout::convolution::NCDHW>::value)
    {
        return {C_ * D * H * W, D * H * W, H * W, W, 1_uz};
    }
    else if constexpr(ck::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWC>::value)
    {
        return {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_};
    }
    else
    {
        // The condition depends on TensorLayout so it is only evaluated when this
        // branch is actually instantiated.
        static_assert(sizeof(TensorLayout) == 0,
                      "f_tensor_strides_ncdhw: only NCDHW and NDHWC layouts are supported");
    }
}
/// Creates a HostTensorDescriptor for a 5-D tensor with lengths {N, C, D, H, W} and the
/// strides that correspond to the memory layout selected by `layout`.
///
/// Strides (in the same N, C, D, H, W dimension order as the lengths):
///   - NCDHW: {C*D*H*W, D*H*W, H*W, W, 1}
///   - NDHWC: {D*C*H*W, 1, C*H*W, W*C, C}
///
/// @param N_,C_  batch and channel lengths
/// @param D,H,W  spatial lengths
/// @param layout tag object; only its type is inspected
///
/// Any other layout type is rejected at compile time. Previously such an instantiation
/// would flow off the end of this non-void function without returning a value.
template <typename TensorLayout>
HostTensorDescriptor f_host_tensor_descriptor(std::size_t N_,
                                              std::size_t C_,
                                              std::size_t D,
                                              std::size_t H,
                                              std::size_t W,
                                              TensorLayout layout)
{
    using namespace ck::literals;

    // Only the type of the tag matters; silence unused-parameter warnings.
    (void)layout;

    if constexpr(ck::is_same<TensorLayout, ck::tensor_layout::convolution::NCDHW>::value)
    {
        return HostTensorDescriptor({N_, C_, D, H, W},
                                    {C_ * D * H * W, D * H * W, H * W, W, 1_uz});
    }
    else if constexpr(ck::is_same<TensorLayout, ck::tensor_layout::convolution::NDHWC>::value)
    {
        return HostTensorDescriptor({N_, C_, D, H, W},
                                    {D * C_ * H * W, 1_uz, C_ * H * W, W * C_, C_});
    }
    else
    {
        // The condition depends on TensorLayout so it is only evaluated when this
        // branch is actually instantiated.
        static_assert(sizeof(TensorLayout) == 0,
                      "f_host_tensor_descriptor: only NCDHW and NDHWC layouts are supported");
    }
}
template
<
typename
DevicePoolBwdInstance
,
typename
DOutDataType
,
typename
DInDataType
,
typename
DOutLayout
,
typename
DInLayout
>
bool
pool3d_bwd_test
(
bool
do_verification
,
bool
time_kernel
,
ck
::
index_t
N
,
ck
::
index_t
C
,
ck
::
index_t
Di
,
ck
::
index_t
Hi
,
ck
::
index_t
Wi
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
,
std
::
vector
<
ck
::
index_t
>
window_dilations
,
std
::
vector
<
ck
::
index_t
>
dinput_left_pads
,
std
::
vector
<
ck
::
index_t
>
dinput_right_pads
)
{
auto
OutSpatialLength
=
[
&
](
auto
InSpatialLength
,
int
index
)
{
ck
::
index_t
left_pad
=
dinput_left_pads
[
index
];
ck
::
index_t
right_pad
=
dinput_right_pads
[
index
];
ck
::
index_t
window_len
=
window_lengths
[
index
];
ck
::
index_t
stride
=
window_strides
[
index
];
ck
::
index_t
dilation
=
window_dilations
[
index
];
ck
::
index_t
eff
=
(
window_len
-
1
)
*
dilation
+
1
;
return
(
InSpatialLength
+
left_pad
+
right_pad
-
eff
)
/
stride
+
1
;
};
ck
::
index_t
Do
=
OutSpatialLength
(
Di
,
0
);
ck
::
index_t
Ho
=
OutSpatialLength
(
Hi
,
1
);
ck
::
index_t
Wo
=
OutSpatialLength
(
Wi
,
2
);
Tensor
<
DOutDataType
>
dout
(
f_host_tensor_descriptor
(
N
,
C
,
Do
,
Ho
,
Wo
,
DOutLayout
{}));
Tensor
<
DInDataType
>
din_dev
(
f_host_tensor_descriptor
(
N
,
C
,
Di
,
Hi
,
Wi
,
DInLayout
{}));
Tensor
<
DInDataType
>
din_host
(
f_host_tensor_descriptor
(
N
,
C
,
Di
,
Hi
,
Wi
,
DInLayout
{}));
std
::
cout
<<
"dout: "
<<
dout
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"din_host: "
<<
din_host
.
mDesc
<<
std
::
endl
;
dout
.
GenerateTensorValue
(
GeneratorTensor_3
<
DOutDataType
>
{
0.0
,
1.0
});
DeviceMem
dout_device_buf
(
sizeof
(
DOutDataType
)
*
dout
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
din_device_buf
(
sizeof
(
DInDataType
)
*
din_dev
.
mDesc
.
GetElementSpaceSize
());
dout_device_buf
.
ToDevice
(
dout
.
mData
.
data
());
din_device_buf
.
SetZero
();
auto
pool
=
DevicePoolBwdInstance
{};
auto
invoker_ptr
=
pool
.
MakeInvokerPointer
();
auto
argument_ptr
=
pool
.
MakeArgumentPointer
(
static_cast
<
DOutDataType
*>
(
dout_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DInDataType
*>
(
din_device_buf
.
GetDeviceBuffer
()),
{
N
,
C
,
Do
,
Ho
,
Wo
},
{
N
,
C
,
Di
,
Hi
,
Wi
},
f_tensor_strides_ncdhw
(
N
,
C
,
Do
,
Ho
,
Wo
,
DOutLayout
{}),
f_tensor_strides_ncdhw
(
N
,
C
,
Di
,
Hi
,
Wi
,
DInLayout
{}),
window_lengths
,
window_strides
,
window_dilations
,
dinput_left_pads
,
dinput_right_pads
);
if
(
!
pool
.
IsSupportedArgument
(
argument_ptr
.
get
()))
{
throw
std
::
runtime_error
(
"wrong! device_op with the specified compilation parameters does "
"not support this problem"
);
}
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
cout
<<
"Perf: "
<<
ave_time
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
{
auto
ref_pool
=
ck
::
tensor_operation
::
host
::
ReferenceAvgPoolBwd
<
3
,
DInDataType
,
DOutDataType
>
();
auto
ref_invoker
=
ref_pool
.
MakeInvoker
();
auto
ref_argument
=
ref_pool
.
MakeArgument
(
din_host
,
dout
,
window_lengths
,
window_strides
,
window_dilations
,
dinput_left_pads
,
dinput_right_pads
);
ref_invoker
.
Run
(
ref_argument
);
din_device_buf
.
FromDevice
(
din_dev
.
mData
.
data
());
pass
=
ck
::
utils
::
check_err
(
din_dev
,
din_host
);
}
return
pass
;
}
example/51_avgpool3d_bwd/avgpool3d_bwd_fp16.cpp
0 → 100644
View file @
ac76519a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp"
#include "avgpool3d_bwd_common.hpp"
// Element types used by this example: DOutDataType / DInDataType are the element types
// of the `dout` and `din` tensors built in pool3d_bwd_test; ComputeDataType is float.
using DOutDataType    = ck::half_t;
using DInDataType     = ck::half_t;
using ComputeDataType = float;

// Tensor layout selection. The `#if 1` branch (NDHWC for both tensors) is the one that
// is compiled; the NCDHW alternative in the `#else` branch is currently disabled.
#if 1
using DOutLayout = ck::tensor_layout::convolution::NDHWC;
using DInLayout  = ck::tensor_layout::convolution::NDHWC;
#else
using DOutLayout = ck::tensor_layout::convolution::NCDHW;
using DInLayout  = ck::tensor_layout::convolution::NCDHW;
#endif

// Device operation instantiated by this example. Each trailing comment names the
// template argument on the same line.
using DevicePoolBwdInstance =
    ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC<DOutDataType,
                                                                 DInDataType,
                                                                 ComputeDataType,
                                                                 64, // BlockSize
                                                                 64, // ReduceMThreadClusterSize
                                                                 1,  // ReduceKThreadClusterSize
                                                                 1,  // ReduceMThreadSliceSize
                                                                 1,  // ReduceKThreadSliceSize
                                                                 1>; // InSrcOutDstVectorSize
int
main
()
{
std
::
vector
<
ck
::
index_t
>
window_lengths
=
{
5
,
5
,
5
};
std
::
vector
<
ck
::
index_t
>
window_strides
=
{
2
,
2
,
2
};
std
::
vector
<
ck
::
index_t
>
window_dilations
=
{
2
,
2
,
2
};
std
::
vector
<
ck
::
index_t
>
dinput_left_pads
=
{
0
,
0
,
0
};
std
::
vector
<
ck
::
index_t
>
dinput_right_pads
=
{
0
,
0
,
0
};
ck
::
index_t
N
=
1
;
ck
::
index_t
C
=
16
;
ck
::
index_t
Di
=
40
;
ck
::
index_t
Hi
=
40
;
ck
::
index_t
Wi
=
40
;
pool3d_bwd_test
<
DevicePoolBwdInstance
,
DOutDataType
,
DInDataType
,
DOutLayout
,
DInLayout
>
(
true
,
false
,
N
,
C
,
Di
,
Hi
,
Wi
,
window_lengths
,
window_strides
,
window_dilations
,
dinput_left_pads
,
dinput_right_pads
);
}
example/51_avgpool3d_bwd/avgpool3d_bwd_fp32.cpp
0 → 100644
View file @
ac76519a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp"
#include "avgpool3d_bwd_common.hpp"
// Element types used by this example: DOutDataType / DInDataType are the element types
// of the `dout` and `din` tensors built in pool3d_bwd_test; all three types are float.
using DOutDataType    = float;
using DInDataType     = float;
using ComputeDataType = float;

// Tensor layout selection. The `#if 1` branch (NDHWC for both tensors) is the one that
// is compiled; the NCDHW alternative in the `#else` branch is currently disabled.
#if 1
using DOutLayout = ck::tensor_layout::convolution::NDHWC;
using DInLayout  = ck::tensor_layout::convolution::NDHWC;
#else
using DOutLayout = ck::tensor_layout::convolution::NCDHW;
using DInLayout  = ck::tensor_layout::convolution::NCDHW;
#endif

// Device operation instantiated by this example. Each trailing comment names the
// template argument on the same line.
using DevicePoolBwdInstance =
    ck::tensor_operation::device::DeviceAvgPool3dBwd_NDHWC_NDHWC<DOutDataType,
                                                                 DInDataType,
                                                                 ComputeDataType,
                                                                 64, // BlockSize
                                                                 64, // ReduceMThreadClusterSize
                                                                 1,  // ReduceKThreadClusterSize
                                                                 1,  // ReduceMThreadSliceSize
                                                                 1,  // ReduceKThreadSliceSize
                                                                 1>; // InSrcOutDstVectorSize
int
main
()
{
std
::
vector
<
ck
::
index_t
>
window_lengths
=
{
5
,
5
,
5
};
std
::
vector
<
ck
::
index_t
>
window_strides
=
{
2
,
2
,
2
};
std
::
vector
<
ck
::
index_t
>
window_dilations
=
{
2
,
2
,
2
};
std
::
vector
<
ck
::
index_t
>
dinput_left_pads
=
{
0
,
0
,
0
};
std
::
vector
<
ck
::
index_t
>
dinput_right_pads
=
{
0
,
0
,
0
};
ck
::
index_t
N
=
1
;
ck
::
index_t
C
=
16
;
ck
::
index_t
Di
=
40
;
ck
::
index_t
Hi
=
40
;
ck
::
index_t
Wi
=
40
;
pool3d_bwd_test
<
DevicePoolBwdInstance
,
DOutDataType
,
DInDataType
,
DOutLayout
,
DInLayout
>
(
true
,
false
,
N
,
C
,
Di
,
Hi
,
Wi
,
window_lengths
,
window_strides
,
window_dilations
,
dinput_left_pads
,
dinput_right_pads
);
}
include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp
0 → 100644
View file @
ac76519a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
/// Abstract interface of an average-pooling backward device operation.
///
/// Concrete operations implement the two factory functions below: one packs the problem
/// description into a BaseArgument, the other creates the invoker used to run it.
///
/// @tparam NDimSpatial  number of spatial dimensions
/// @tparam DOutDataType element type of the dout tensor
/// @tparam DInDataType  element type of the din tensor
/// @tparam DOutLayout   layout tag of the dout tensor
/// @tparam DInLayout    layout tag of the din tensor
template <index_t NDimSpatial,
          typename DOutDataType,
          typename DInDataType,
          typename DOutLayout,
          typename DInLayout>
struct DeviceAvgPoolBwd : public BaseOperator
{
    /// Packs one problem description into an argument object.
    ///
    /// Parameter names follow the positional order used by the avgpool3d backward
    /// example (dout lengths, din lengths, dout strides, din strides), which passes
    /// 5-entry {N, C, D, H, W} vectors for the tensor descriptions and one entry per
    /// spatial dimension for the window/padding vectors. Only the names were changed;
    /// the parameter types and their order are the same as before.
    /// NOTE(review): confirm against the concrete implementation, which is not visible
    /// here.
    ///
    /// @param p_dout               dout buffer (read-only)
    /// @param p_din                din buffer (written)
    /// @param dout_n_c_wos_lengths lengths of dout
    /// @param din_n_c_wis_lengths  lengths of din
    /// @param dout_n_c_wos_strides strides of dout
    /// @param din_n_c_wis_strides  strides of din
    /// @param window_lengths       pooling window length per spatial dimension
    /// @param window_strides       pooling window stride per spatial dimension
    /// @param window_dilations     pooling window dilation per spatial dimension
    /// @param input_left_pads      left padding per spatial dimension
    /// @param input_right_pads     right padding per spatial dimension
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_dout,
                        void* p_din,
                        std::vector<ck::index_t> dout_n_c_wos_lengths,
                        std::vector<ck::index_t> din_n_c_wis_lengths,
                        std::vector<ck::index_t> dout_n_c_wos_strides,
                        std::vector<ck::index_t> din_n_c_wis_strides,
                        std::vector<ck::index_t> window_lengths,
                        std::vector<ck::index_t> window_strides,
                        std::vector<ck::index_t> window_dilations,
                        std::vector<ck::index_t> input_left_pads,
                        std::vector<ck::index_t> input_right_pads) = 0;

    /// Creates the invoker that runs an argument produced by MakeArgumentPointer.
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp
View file @
ac76519a
...
@@ -27,15 +27,12 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
...
@@ -27,15 +27,12 @@ struct DeviceGroupedConvBwdWeight : public BaseOperator
MakeArgumentPointer
(
const
void
*
p_in
,
MakeArgumentPointer
(
const
void
*
p_in
,
void
*
p_wei
,
void
*
p_wei
,
const
void
*
p_out
,
const
void
*
p_out
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
...
include/ck/tensor_operation/gpu/device/device_put_element.hpp
View file @
ac76519a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
2
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp
0 → 100644
View file @
ac76519a
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp
View file @
ac76519a
...
@@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -123,7 +123,8 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
ALayout
,
ALayout
,
BLayout
,
BLayout
,
CLayout
,
CLayout
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
BDataType
,
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
CDataType
,
CDataType
,
...
@@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -284,8 +285,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ADataType
,
CDataType
,
true
>
;
ADataType
,
BDataType
,
CDataType
,
true
>
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
...
@@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
...
@@ -357,8 +361,11 @@ struct DeviceCGemm_4Gemm_Xdl_CShuffle
}
}
else
else
{
{
const
auto
kernel
=
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
ADataType
,
CDataType
,
false
>
;
ADataType
,
BDataType
,
CDataType
,
false
>
;
ave_time
+=
launch_and_time_kernel
(
stream_config
,
ave_time
+=
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
ac76519a
...
@@ -65,7 +65,8 @@ template <typename ALayout,
...
@@ -65,7 +65,8 @@ template <typename ALayout,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
index_t
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
LoopScheduler
LoopSched
=
make_default_loop_scheduler
(),
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
>
PipelineVersion
PipelineVer
=
PipelineVersion
::
v1
,
typename
ComputeType
=
CDataType
>
struct
DeviceGemm_Xdl_CShuffle
:
public
DeviceGemm
<
ALayout
,
struct
DeviceGemm_Xdl_CShuffle
:
public
DeviceGemm
<
ALayout
,
BLayout
,
BLayout
,
CLayout
,
CLayout
,
...
@@ -87,7 +88,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -87,7 +88,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
ALayout
,
ALayout
,
BLayout
,
BLayout
,
CLayout
,
CLayout
,
ADataType
,
// TODO: distinguish A/B datatype
ADataType
,
BDataType
,
GemmAccDataType
,
GemmAccDataType
,
CShuffleDataType
,
CShuffleDataType
,
CDataType
,
CDataType
,
...
@@ -128,7 +130,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
...
@@ -128,7 +130,8 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
CShuffleBlockTransferScalarPerVector_NPerBlock
,
LoopSched
,
LoopSched
,
PipelineVer
>
;
PipelineVer
,
ComputeType
>
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
using
Argument
=
typename
GridwiseGemm
::
Argument
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_gnwc_gkxc_gnwk_dl.hpp
View file @
ac76519a
...
@@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -784,15 +784,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
Argument
(
const
InDataType
*
p_in_grid
,
Argument
(
const
InDataType
*
p_in_grid
,
WeiDataType
*
p_wei_grid
,
WeiDataType
*
p_wei_grid
,
const
OutDataType
*
p_out_grid
,
const
OutDataType
*
p_out_grid
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/*a_g_n_c_wis_strides*/
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/*b_g_k_c_xs_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
/*e_g_n_k_wos_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
/*input_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
/*output_strides*/
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -812,27 +809,38 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
a_element_op_
{
out_element_op
},
a_element_op_
{
out_element_op
},
b_element_op_
{
wei_element_op
},
b_element_op_
{
wei_element_op
},
c_element_op_
{
in_element_op
},
c_element_op_
{
in_element_op
},
Conv_G_
{
G
},
Conv_G_
{
a_g_n_c_wis_lengths
[
0
]
},
Conv_N_
{
N
},
Conv_N_
{
a_g_n_c_wis_lengths
[
1
]
},
Conv_K_
{
K
},
Conv_K_
{
b_g_k_c_xs_lengths
[
1
]
},
Conv_C_
{
C
},
Conv_C_
{
a_g_n_c_wis_lengths
[
2
]
},
input_spatial_lengths_
{
input_spatial_lengths
},
input_spatial_lengths_
{},
filter_spatial_lengths_
{
filter_spatial_lengths
},
filter_spatial_lengths_
{},
output_spatial_lengths_
{
output_spatial_lengths
},
output_spatial_lengths_
{},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
},
input_right_pads_
{
input_right_pads
},
k_batch_
{
split_k
}
k_batch_
{
split_k
}
{
{
constexpr
index_t
spatial_offset
=
3
;
std
::
copy
(
begin
(
a_g_n_c_wis_lengths
)
+
spatial_offset
,
end
(
a_g_n_c_wis_lengths
),
begin
(
input_spatial_lengths_
));
std
::
copy
(
begin
(
b_g_k_c_xs_lengths
)
+
spatial_offset
,
end
(
b_g_k_c_xs_lengths
),
begin
(
filter_spatial_lengths_
));
std
::
copy
(
begin
(
e_g_n_k_wos_lengths
)
+
spatial_offset
,
end
(
e_g_n_k_wos_lengths
),
begin
(
output_spatial_lengths_
));
const
auto
descs
=
const
auto
descs
=
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
DeviceOp
::
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N
<
NDimSpatial
>
(
N
,
Conv_N_
,
K
,
Conv_K_
,
C
,
C
onv_C_
,
input_spatial_lengths
,
input_spatial_lengths
_
,
filter_spatial_lengths
,
filter_spatial_lengths
_
,
output_spatial_lengths
,
output_spatial_lengths
_
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -856,21 +864,21 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
// A/B/C Batch Stride
// A/B/C Batch Stride
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
N
*
K
*
Conv_N_
*
Conv_K_
*
std
::
accumulate
(
begin
(
output_spatial_lengths
),
std
::
accumulate
(
begin
(
output_spatial_lengths
_
),
end
(
output_spatial_lengths
),
end
(
output_spatial_lengths
_
),
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<>
{});
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
N
*
C
*
Conv_N_
*
Conv_C_
*
std
::
accumulate
(
begin
(
input_spatial_lengths
),
std
::
accumulate
(
begin
(
input_spatial_lengths
_
),
end
(
input_spatial_lengths
),
end
(
input_spatial_lengths
_
),
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<>
{});
std
::
multiplies
<>
{});
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
compute_ptr_offset_of_batch_
.
BatchStrideC_
=
K
*
C
*
Conv_K_
*
Conv_C_
*
std
::
accumulate
(
begin
(
filter_spatial_lengths
),
std
::
accumulate
(
begin
(
filter_spatial_lengths
_
),
end
(
filter_spatial_lengths
),
end
(
filter_spatial_lengths
_
),
index_t
{
1
},
index_t
{
1
},
std
::
multiplies
<>
{});
std
::
multiplies
<>
{});
}
}
...
@@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -904,9 +912,9 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
const
index_t
Conv_K_
;
const
index_t
Conv_K_
;
const
index_t
Conv_C_
;
const
index_t
Conv_C_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
input_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
input_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
filter_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
filter_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
&
output_spatial_lengths_
;
std
::
array
<
ck
::
index_t
,
NDimSpatial
>
output_spatial_lengths_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads_
;
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads_
;
...
@@ -1110,39 +1118,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -1110,39 +1118,34 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
return
IsSupportedArgument
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
}
static
auto
MakeArgument
(
const
InDataType
*
p_in_grid
,
static
auto
WeiDataType
*
p_wei_grid
,
MakeArgument
(
const
InDataType
*
p_in_grid
,
const
OutDataType
*
p_out_grid
,
WeiDataType
*
p_wei_grid
,
const
ck
::
index_t
G
,
const
OutDataType
*
p_out_grid
,
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
InElementwiseOperation
in_element_op
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_right_pads
,
WeiElementwiseOperation
wei_element_op
,
InElementwiseOperation
in_element_op
,
OutElementwiseOperation
out_element_op
,
WeiElementwiseOperation
wei_element_op
,
ck
::
index_t
split_k
)
OutElementwiseOperation
out_element_op
,
ck
::
index_t
split_k
)
{
{
return
Argument
{
p_in_grid
,
return
Argument
{
p_in_grid
,
p_wei_grid
,
p_wei_grid
,
p_out_grid
,
p_out_grid
,
G
,
a_g_n_c_wis_lengths
,
// input
N
,
a_g_n_c_wis_strides
,
K
,
b_g_k_c_xs_lengths
,
// weight
C
,
b_g_k_c_xs_strides
,
input_spatial_lengths
,
e_g_n_k_wos_lengths
,
// output
filter_spatial_lengths
,
e_g_n_k_wos_strides
,
output_spatial_lengths
,
input_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
@@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -1159,15 +1162,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
MakeArgumentPointer
(
const
void
*
p_in_grid
,
MakeArgumentPointer
(
const
void
*
p_in_grid
,
void
*
p_wei_grid
,
void
*
p_wei_grid
,
const
void
*
p_out_grid
,
const
void
*
p_out_grid
,
const
ck
::
index_t
G
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
// input
const
ck
::
index_t
N
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
index_t
K
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
// weight
const
ck
::
index_t
C
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
// output
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
filter_spatial_lengths
,
const
std
::
array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
output_spatial_lengths
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
input_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
+
3
>&
output_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
const
std
::
array
<
ck
::
index_t
,
NDimSpatial
>&
input_left_pads
,
...
@@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
...
@@ -1180,15 +1180,12 @@ struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Dl
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_in_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
WeiDataType
*>
(
p_wei_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
static_cast
<
const
OutDataType
*>
(
p_out_grid
),
G
,
a_g_n_c_wis_lengths
,
// input
N
,
a_g_n_c_wis_strides
,
K
,
b_g_k_c_xs_lengths
,
// weight
C
,
b_g_k_c_xs_strides
,
input_spatial_lengths
,
e_g_n_k_wos_lengths
,
// output
filter_spatial_lengths
,
e_g_n_k_wos_strides
,
output_spatial_lengths
,
input_strides
,
output_strides
,
conv_filter_strides
,
conv_filter_strides
,
conv_filter_dilations
,
conv_filter_dilations
,
input_left_pads
,
input_left_pads
,
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_xdl_cshuffle.hpp
View file @
ac76519a
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
6
7
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment