Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1d56022b
Commit
1d56022b
authored
Jun 01, 2023
by
rocking
Browse files
Add maxpool f32 kernel and example
parent
ac9e01e2
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
742 additions
and
0 deletions
+742
-0
example/49_maxpool2d_bwd/CMakeLists.txt
example/49_maxpool2d_bwd/CMakeLists.txt
+1
-0
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
+231
-0
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp
+75
-0
include/ck/tensor_operation/gpu/device/device_put_element.hpp
...ude/ck/tensor_operation/gpu/device/device_put_element.hpp
+36
-0
include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp
...sor_operation/gpu/device/impl/device_put_element_impl.hpp
+152
-0
include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp
.../ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp
+146
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp
.../reference_tensor_operation/cpu/reference_maxpool_bwd.hpp
+101
-0
No files found.
example/49_maxpool2d_bwd/CMakeLists.txt
0 → 100644
View file @
1d56022b
add_example_executable
(
example_maxpool2d_bwd_fp32 maxpool2d_bwd_fp32.cpp
)
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
0 → 100644
View file @
1d56022b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include "ck/ck.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/literals.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_pool_fwd.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp"
template
<
typename
InDataType
,
typename
OutDataType
,
typename
IndexDataType
,
typename
ComputeDataType
,
typename
DInDataType
,
typename
DOutDataType
,
typename
InLayout
,
typename
OutLayout
,
ck
::
ReduceTensorOp
ReduceOpId
,
bool
PropagateNan
,
ck
::
InMemoryDataOperationEnum
Memop
>
bool
maxpool_bwd_test
(
bool
do_verification
,
bool
time_kernel
,
ck
::
index_t
N
,
ck
::
index_t
C
,
ck
::
index_t
Y
,
ck
::
index_t
X
,
ck
::
index_t
Hi
,
ck
::
index_t
Wi
,
ck
::
index_t
window_stride_h
,
ck
::
index_t
window_stride_w
,
ck
::
index_t
in_left_pad_h
,
ck
::
index_t
in_left_pad_w
,
ck
::
index_t
in_right_pad_h
,
ck
::
index_t
in_right_pad_w
)
{
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
DevicePoolFwdInstance
=
ck
::
tensor_operation
::
device
::
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
<
InDataType
,
// InDataType
OutDataType
,
// OutDataType
IndexDataType
,
// IndexDataType
ComputeDataType
,
// ComputeDataType
ReduceOpId
,
true
,
// OutputIndex
64
,
// BlockSize
64
,
// ReduceMThreadClusterSize
1
,
// ReduceKThreadClusterSize
4
,
// ReduceMThreadSliceSize
1
,
// ReduceKThreadSliceSize
1
>
;
// InSrcOutDstVectorSize
using
DeviceMaxPoolBwdInstance
=
ck
::
tensor_operation
::
device
::
DevicePutElementImpl
<
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
,
Memop
,
4
>
;
const
ck
::
index_t
Ho
=
(
Hi
+
in_left_pad_h
+
in_right_pad_h
-
Y
)
/
window_stride_h
+
1
;
const
ck
::
index_t
Wo
=
(
Wi
+
in_left_pad_w
+
in_right_pad_w
-
X
)
/
window_stride_w
+
1
;
const
std
::
vector
<
ck
::
index_t
>
window_spatial_lengths
{
Y
,
X
};
const
std
::
vector
<
ck
::
index_t
>
window_strides
{
window_stride_h
,
window_stride_w
};
const
std
::
vector
<
ck
::
index_t
>
input_left_pads
{
in_left_pad_h
,
in_left_pad_w
};
const
std
::
vector
<
ck
::
index_t
>
input_right_pads
{
in_right_pad_h
,
in_right_pad_w
};
// tensor layout
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
N_
,
std
::
size_t
C_
,
std
::
size_t
H
,
std
::
size_t
W
,
auto
layout
)
{
using
namespace
ck
::
literals
;
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NCHW
>::
value
)
{
return
HostTensorDescriptor
({
N_
,
C_
,
H
,
W
},
{
C_
*
H
*
W
,
H
*
W
,
W
,
1
_uz
});
}
else
if
constexpr
(
ck
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
convolution
::
NHWC
>::
value
)
{
return
HostTensorDescriptor
({
N_
,
C_
,
H
,
W
},
{
C_
*
H
*
W
,
1
_uz
,
W
*
C_
,
C_
});
}
};
// in
Tensor
<
InDataType
>
in_n_c_hi_wi
(
f_host_tensor_descriptor
(
N
,
C
,
Hi
,
Wi
,
InLayout
{}));
// out
Tensor
<
OutDataType
>
out_n_c_ho_wo_host
(
f_host_tensor_descriptor
(
N
,
C
,
Ho
,
Wo
,
OutLayout
{}));
Tensor
<
OutDataType
>
out_n_c_ho_wo_device
(
f_host_tensor_descriptor
(
N
,
C
,
Ho
,
Wo
,
OutLayout
{}));
// indices
Tensor
<
IndexDataType
>
indices_n_c_ho_wo_device
(
f_host_tensor_descriptor
(
N
,
C
,
Ho
,
Wo
,
OutLayout
{}));
Tensor
<
IndexDataType
>
indices_n_c_ho_wo_host
(
f_host_tensor_descriptor
(
N
,
C
,
Ho
,
Wo
,
OutLayout
{}));
// dout
Tensor
<
DOutDataType
>
dout_n_c_ho_wo
(
f_host_tensor_descriptor
(
N
,
C
,
Ho
,
Wo
,
OutLayout
{}));
// din
Tensor
<
DInDataType
>
din_n_c_hi_wi_host
(
f_host_tensor_descriptor
(
N
,
C
,
Hi
,
Wi
,
InLayout
{}));
Tensor
<
DInDataType
>
din_n_c_hi_wi_device
(
f_host_tensor_descriptor
(
N
,
C
,
Hi
,
Wi
,
InLayout
{}));
std
::
cout
<<
"in_n_c_hi_wi: "
<<
in_n_c_hi_wi
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"out_n_c_ho_wo: "
<<
out_n_c_ho_wo_host
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"indices_n_c_ho_wo: "
<<
indices_n_c_ho_wo_host
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"dout_n_c_ho_wo: "
<<
dout_n_c_ho_wo
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"din_n_c_hi_wi: "
<<
din_n_c_hi_wi_host
.
mDesc
<<
std
::
endl
;
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_3
<
InDataType
>
{
-
1.0
,
1.0
});
dout_n_c_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_3
<
DOutDataType
>
{
-
1.0
,
1.0
});
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
out_device_buf
(
sizeof
(
OutDataType
)
*
out_n_c_ho_wo_device
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
indices_device_buf
(
sizeof
(
IndexDataType
)
*
indices_n_c_ho_wo_device
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
dout_device_buf
(
sizeof
(
DOutDataType
)
*
dout_n_c_ho_wo
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
din_device_buf
(
sizeof
(
DInDataType
)
*
din_n_c_hi_wi_device
.
mDesc
.
GetElementSpaceSize
());
in_device_buf
.
ToDevice
(
in_n_c_hi_wi
.
mData
.
data
());
dout_device_buf
.
ToDevice
(
dout_n_c_ho_wo
.
mData
.
data
());
din_device_buf
.
SetZero
();
auto
pool_fwd
=
DevicePoolFwdInstance
{};
auto
pool_fwd_invoker_ptr
=
pool_fwd
.
MakeInvokerPointer
();
auto
pool_fwd_argument_ptr
=
pool_fwd
.
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
indices_device_buf
.
GetDeviceBuffer
()),
{
N
,
C
,
Hi
,
Wi
},
{
Y
,
X
},
{
N
,
C
,
Ho
,
Wo
},
{
C
*
Hi
*
Wi
,
1
,
Wi
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
{
C
*
Ho
*
Wo
,
1
,
Wo
*
C
,
C
},
window_strides
,
input_left_pads
,
input_right_pads
,
{
2
,
3
});
if
(
!
pool_fwd
.
IsSupportedArgument
(
pool_fwd_argument_ptr
.
get
()))
{
throw
std
::
runtime_error
(
"wrong! pool_fwd with the specified compilation parameters does "
"not support this problem"
);
}
float
ave_time_fwd
=
pool_fwd_invoker_ptr
->
Run
(
pool_fwd_argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
auto
pool_bwd
=
DeviceMaxPoolBwdInstance
{};
auto
pool_bwd_invoker_ptr
=
pool_bwd
.
MakeInvokerPointer
();
auto
pool_bwd_argument_ptr
=
pool_bwd
.
MakeArgumentPointer
(
static_cast
<
DOutDataType
*>
(
dout_device_buf
.
GetDeviceBuffer
()),
static_cast
<
IndexDataType
*>
(
indices_device_buf
.
GetDeviceBuffer
()),
static_cast
<
DInDataType
*>
(
din_device_buf
.
GetDeviceBuffer
()),
dout_n_c_ho_wo
.
mDesc
.
GetElementSpaceSize
(),
din_n_c_hi_wi_device
.
mDesc
.
GetElementSpaceSize
(),
PassThrough
{});
if
(
!
pool_bwd
.
IsSupportedArgument
(
pool_bwd_argument_ptr
.
get
()))
{
throw
std
::
runtime_error
(
"wrong! pool_bwd with the specified compilation parameters does "
"not support this problem"
);
}
float
ave_time_bwd
=
pool_bwd_invoker_ptr
->
Run
(
pool_bwd_argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
std
::
cout
<<
"Pool fwd perf: "
<<
ave_time_fwd
<<
" ms"
<<
std
::
endl
;
std
::
cout
<<
"Pool bwd perf: "
<<
ave_time_bwd
<<
" ms"
<<
std
::
endl
;
bool
pass
=
true
;
if
(
do_verification
)
{
using
ReferencePoolingFwdInstance
=
ck
::
tensor_operation
::
host
::
ReferencePoolingFwd
<
4
,
2
,
InDataType
,
OutDataType
,
ComputeDataType
,
IndexDataType
,
ReduceOpId
,
PropagateNan
,
true
>
;
auto
ref_pooling_fwd
=
ReferencePoolingFwdInstance
{};
auto
ref_pooling_fwd_invoker
=
ref_pooling_fwd
.
MakeInvoker
();
auto
ref_pooling_fwd_argument
=
ref_pooling_fwd
.
MakeArgument
(
in_n_c_hi_wi
,
out_n_c_ho_wo_host
,
indices_n_c_ho_wo_host
,
window_spatial_lengths
,
window_strides
,
input_left_pads
,
input_right_pads
);
ref_pooling_fwd_invoker
.
Run
(
ref_pooling_fwd_argument
);
using
ReferencePoolingBwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceMaxPoolBwd
<
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
>
;
auto
ref_pooling_bwd
=
ReferencePoolingBwdInstance
{};
auto
ref_pooling_bwd_invoker
=
ref_pooling_bwd
.
MakeInvoker
();
auto
ref_pooling_bwd_argument
=
ref_pooling_bwd
.
MakeArgument
(
dout_n_c_ho_wo
,
indices_n_c_ho_wo_host
,
din_n_c_hi_wi_host
,
PassThrough
{});
ref_pooling_bwd_invoker
.
Run
(
ref_pooling_bwd_argument
);
out_device_buf
.
FromDevice
(
out_n_c_ho_wo_device
.
mData
.
data
());
indices_device_buf
.
FromDevice
(
indices_n_c_ho_wo_device
.
mData
.
data
());
din_device_buf
.
FromDevice
(
din_n_c_hi_wi_device
.
mData
.
data
());
pass
=
pass
&&
ck
::
utils
::
check_err
(
out_n_c_ho_wo_device
,
out_n_c_ho_wo_host
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
indices_n_c_ho_wo_device
,
indices_n_c_ho_wo_host
);
pass
=
pass
&&
ck
::
utils
::
check_err
(
din_n_c_hi_wi_device
,
din_n_c_hi_wi_host
);
}
return
(
pass
);
};
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp
0 → 100644
View file @
1d56022b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/reduction_enums.hpp"
#include "maxpool2d_bwd_common.hpp"
using
InDataType
=
float
;
using
OutDataType
=
float
;
using
IndexDataType
=
int32_t
;
using
ComputeDataType
=
float
;
using
DInDataType
=
float
;
using
DOutDataType
=
float
;
using
InLayout
=
ck
::
tensor_layout
::
convolution
::
NHWC
;
using
OutLayout
=
ck
::
tensor_layout
::
convolution
::
NHWC
;
static
constexpr
bool
PropagateNan
=
false
;
int
main
()
{
bool
do_verification
=
true
;
bool
time_kernel
=
false
;
// Pool shape
constexpr
ck
::
index_t
N
=
1
;
constexpr
ck
::
index_t
C
=
1
;
constexpr
ck
::
index_t
Y
=
2
;
constexpr
ck
::
index_t
X
=
2
;
constexpr
ck
::
index_t
Hi
=
31
;
constexpr
ck
::
index_t
Wi
=
31
;
constexpr
ck
::
index_t
window_stride_h
=
2
;
constexpr
ck
::
index_t
window_stride_w
=
2
;
constexpr
ck
::
index_t
in_left_pad_h
=
0
;
constexpr
ck
::
index_t
in_left_pad_w
=
0
;
constexpr
ck
::
index_t
in_right_pad_h
=
1
;
constexpr
ck
::
index_t
in_right_pad_w
=
1
;
constexpr
bool
WindowOverlap
=
Y
>
window_stride_h
||
X
>
window_stride_w
;
constexpr
ck
::
InMemoryDataOperationEnum
MemOp
=
WindowOverlap
?
ck
::
InMemoryDataOperationEnum
::
AtomicAdd
:
ck
::
InMemoryDataOperationEnum
::
Set
;
std
::
cout
<<
"WindowOverlap = "
<<
WindowOverlap
<<
std
::
endl
;
bool
pass
=
maxpool_bwd_test
<
InDataType
,
OutDataType
,
IndexDataType
,
ComputeDataType
,
DInDataType
,
DOutDataType
,
InLayout
,
OutLayout
,
ck
::
ReduceTensorOp
::
MAX
,
PropagateNan
,
MemOp
>
(
do_verification
,
time_kernel
,
N
,
C
,
Y
,
X
,
Hi
,
Wi
,
window_stride_h
,
window_stride_w
,
in_left_pad_h
,
in_left_pad_w
,
in_right_pad_h
,
in_right_pad_w
);
return
(
pass
?
0
:
1
);
}
include/ck/tensor_operation/gpu/device/device_put_element.hpp
0 → 100644
View file @
1d56022b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <vector>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/utility/reduction_enums.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// output[indices] = input
template
<
typename
InDataType
,
typename
IndexDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
InMemoryDataOperationEnum
Op
>
struct
DevicePutElement
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_input
,
const
void
*
p_indices
,
void
*
p_output
,
index_t
input_length
,
index_t
output_length
,
ElementwiseOperation
elementwise_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_put_element_impl.hpp
0 → 100644
View file @
1d56022b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_put_element.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
// output[indices] = input
template
<
typename
InDataType
,
typename
IndexDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
InMemoryDataOperationEnum
MemOp
,
ck
::
index_t
InVectorSize
>
struct
DevicePutElementImpl
:
public
DevicePutElement
<
InDataType
,
IndexDataType
,
OutDataType
,
ElementwiseOperation
,
MemOp
>
{
template
<
typename
Desc_M
>
static
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
,
index_t
gridSize
,
index_t
blockSize
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
const
auto
m
=
desc_m
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
InVectorSize
;
const
auto
pad
=
math
::
integer_least_multiple
(
m
,
loop_step
)
-
m
;
const
auto
desc_m_pad
=
transform_tensor_descriptor
(
desc_m
,
make_tuple
(
make_right_pad_transform
(
m
,
pad
)),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
desc_m_pad
;
}
static
auto
MakeDescriptor_M
(
index_t
length
,
index_t
gridSize
,
index_t
blockSize
)
{
const
auto
desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
length
));
return
PadDescriptor_M_1d
(
desc_m
,
gridSize
,
blockSize
);
}
using
InGrid1dDesc
=
decltype
(
MakeDescriptor_M
(
1
,
1
,
1
));
using
GridwisePutElement
=
GridwisePutElement_1D
<
InGrid1dDesc
,
InDataType
,
IndexDataType
,
OutDataType
,
ElementwiseOperation
,
MemOp
,
InVectorSize
>
;
struct
Argument
:
public
BaseArgument
{
Argument
(
const
InDataType
*
p_input
,
const
IndexDataType
*
p_indices
,
OutDataType
*
p_output
,
index_t
input_length
,
ElementwiseOperation
elementwise_op
)
:
p_input_
{
p_input
},
p_indices_
{
p_indices
},
p_output_
{
p_output
},
elementwise_op_
{
elementwise_op
},
blockSize_
{
256
},
gridSize_
{
104
}
// FIXME - Calculate the grid size by number of CU in the future
{
in_grid_desc_
=
MakeDescriptor_M
(
input_length
,
gridSize_
,
blockSize_
);
}
const
InDataType
*
p_input_
;
const
IndexDataType
*
p_indices_
;
OutDataType
*
p_output_
;
ElementwiseOperation
elementwise_op_
;
index_t
blockSize_
;
index_t
gridSize_
;
InGrid1dDesc
in_grid_desc_
;
};
struct
Invoker
:
public
BaseInvoker
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
kernel
=
kernel_put_element_1d
<
GridwisePutElement
,
InGrid1dDesc
,
InDataType
,
IndexDataType
,
OutDataType
,
ElementwiseOperation
,
MemOp
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
arg
.
gridSize_
),
dim3
(
arg
.
blockSize_
),
0
,
arg
.
in_grid_desc_
,
arg
.
p_input_
,
arg
.
p_indices_
,
arg
.
p_output_
,
arg
.
elementwise_op_
);
return
elapsed_time
;
}
float
Run
(
const
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
// TODO
ignore
=
pArg
;
return
true
;
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_input
,
const
void
*
p_indices
,
void
*
p_output
,
index_t
input_length
,
index_t
,
ElementwiseOperation
elementwise_op
)
override
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
InDataType
*>
(
p_input
),
static_cast
<
const
IndexDataType
*>
(
p_indices
),
static_cast
<
OutDataType
*>
(
p_output
),
input_length
,
elementwise_op
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp
0 → 100644
View file @
1d56022b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
namespace
ck
{
template
<
typename
GridwisePutElementwise1dFunctor
,
typename
InGrid1dDesc
,
typename
InDataType
,
typename
IndexDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
InMemoryDataOperationEnum
MemOp
>
__global__
void
kernel_put_element_1d
(
const
InGrid1dDesc
in_grid_1d_desc
,
const
InDataType
*
__restrict__
p_in_global
,
const
IndexDataType
*
__restrict__
p_indices_global
,
OutDataType
*
__restrict__
p_out_global
,
const
ElementwiseOperation
elementwise_op
)
{
GridwisePutElementwise1dFunctor
::
Run
(
in_grid_1d_desc
,
p_in_global
,
p_indices_global
,
p_out_global
,
elementwise_op
);
}
template
<
typename
InGrid1dDesc
,
typename
InDataType
,
typename
IndexDataType
,
typename
OutDataType
,
typename
ElementwiseOperation
,
InMemoryDataOperationEnum
MemOp
,
index_t
InVectorSize
>
struct
GridwisePutElement_1D
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
thread_buffer_desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
InVectorSize
>
{}));
__device__
static
void
Run
(
const
InGrid1dDesc
&
in_grid_1d_desc
,
const
InDataType
*
__restrict__
p_in_global
,
const
IndexDataType
*
__restrict__
p_indices_global
,
OutDataType
*
__restrict__
p_out_global
,
const
ElementwiseOperation
&
elementwise_op
)
{
// Global Memory
const
auto
in_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_in_global
,
in_grid_1d_desc
.
GetElementSpaceSize
());
const
auto
indices_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_indices_global
,
in_grid_1d_desc
.
GetElementSpaceSize
(),
NumericLimits
<
IndexDataType
>::
Lowest
());
// VGPR
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
InDataType
,
InVectorSize
,
true
>
in_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
IndexDataType
,
InVectorSize
,
true
>
indices_thread_buf
;
// Thread id, Block id and index
const
index_t
thread_global_id
=
get_thread_global_1d_id
();
const
auto
thread_global_offset
=
make_multi_index
(
thread_global_id
*
InVectorSize
);
const
index_t
blockSize
=
get_block_size
();
const
index_t
blockPerGrid
=
get_grid_size
();
const
auto
M
=
in_grid_1d_desc
.
GetLength
(
I0
);
const
index_t
loop_step
=
blockPerGrid
*
blockSize
*
InVectorSize
;
const
auto
loop_step_index
=
make_multi_index
(
loop_step
);
auto
in_global_load
=
ThreadwiseTensorSliceTransfer_v2
<
InDataType
,
InDataType
,
decltype
(
in_grid_1d_desc
),
decltype
(
thread_buffer_desc_m
),
Sequence
<
InVectorSize
>
,
// SliceLengths
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
InVectorSize
,
// ScalarPerVector
1
,
// SrcScalarStrideInVector
false
>
{
in_grid_1d_desc
,
thread_global_offset
};
auto
indices_global_load
=
ThreadwiseTensorSliceTransfer_v2
<
IndexDataType
,
IndexDataType
,
decltype
(
in_grid_1d_desc
),
decltype
(
thread_buffer_desc_m
),
Sequence
<
InVectorSize
>
,
// SliceLengths
Sequence
<
0
>
,
// DimAccessOrder
0
,
// SrcVectorDim
InVectorSize
,
// ScalarPerVector
1
,
// SrcScalarStrideInVector
false
>
{
in_grid_1d_desc
,
thread_global_offset
};
index_t
num_iter
=
M
/
loop_step
;
do
{
in_global_load
.
Run
(
in_grid_1d_desc
,
in_global_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
in_thread_buf
);
in_global_load
.
MoveSrcSliceWindow
(
in_grid_1d_desc
,
loop_step_index
);
static_for
<
0
,
InVectorSize
,
1
>
{}(
[
&
](
auto
iM
)
{
elementwise_op
(
in_thread_buf
(
iM
),
in_thread_buf
[
iM
]);
});
indices_global_load
.
Run
(
in_grid_1d_desc
,
indices_global_buf
,
thread_buffer_desc_m
,
make_tuple
(
I0
),
indices_thread_buf
);
indices_global_load
.
MoveSrcSliceWindow
(
in_grid_1d_desc
,
loop_step_index
);
static_for
<
0
,
InVectorSize
,
1
>
{}([
&
](
auto
iM
)
{
if
(
indices_thread_buf
[
iM
]
>=
0
)
{
// TODO - Support other operations
static_assert
(
MemOp
==
InMemoryDataOperationEnum
::
Set
||
MemOp
==
InMemoryDataOperationEnum
::
AtomicAdd
);
if
constexpr
(
MemOp
==
InMemoryDataOperationEnum
::
Set
)
{
// User should guarantee each index in p_indices_global is different
*
(
p_out_global
+
indices_thread_buf
[
iM
])
=
ck
::
type_convert
<
OutDataType
>
(
in_thread_buf
[
iM
]);
}
else
if
constexpr
(
MemOp
==
InMemoryDataOperationEnum
::
AtomicAdd
)
{
atomic_add
<
OutDataType
>
(
p_out_global
+
indices_thread_buf
[
iM
],
ck
::
type_convert
<
OutDataType
>
(
in_thread_buf
[
iM
]));
}
else
{
// TODO
}
}
});
}
while
(
--
num_iter
);
}
};
}
// namespace ck
library/include/ck/library/reference_tensor_operation/cpu/reference_maxpool_bwd.hpp
0 → 100644
View file @
1d56022b
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include <vector>
#include <algorithm>
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/utility/reduction_functions_accumulate.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
host
{
using
namespace
std
;
template
<
typename
DOutDataType
,
typename
IndexDataType
,
typename
DInDataType
,
typename
ElementwiseOperation
>
struct
ReferenceMaxPoolBwd
:
public
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
Tensor
<
DOutDataType
>&
dout
,
const
Tensor
<
IndexDataType
>&
indices
,
Tensor
<
DInDataType
>&
din
,
ElementwiseOperation
elementwise_op
)
:
dout_
(
dout
),
indices_
(
indices
),
din_
(
din
),
elementwise_op_
(
elementwise_op
)
{
}
const
Tensor
<
DOutDataType
>&
dout_
;
const
Tensor
<
IndexDataType
>&
indices_
;
Tensor
<
DInDataType
>&
din_
;
ElementwiseOperation
elementwise_op_
;
};
// Invoker
struct
Invoker
:
public
device
::
BaseInvoker
{
float
Run
(
const
Argument
&
arg
)
{
int
din_length
=
arg
.
din_
.
GetElementSpaceSize
();
int
dout_length
=
arg
.
dout_
.
GetElementSpaceSize
();
for
(
int
i
=
0
;
i
<
dout_length
;
++
i
)
{
int
index
=
arg
.
indices_
.
mData
[
i
];
if
(
index
>=
0
&&
index
<
din_length
)
arg
.
din_
.
mData
[
index
]
+=
arg
.
dout_
.
mData
[
i
];
}
return
0
;
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
/* stream_config */
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
};
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
static
auto
MakeArgument
(
const
Tensor
<
DOutDataType
>&
dout
,
const
Tensor
<
IndexDataType
>&
indices
,
Tensor
<
DInDataType
>&
din
,
ElementwiseOperation
elementwise_op
)
{
return
Argument
{
dout
,
indices
,
din
,
elementwise_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
virtual
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"ReferenceMaxPoolBwd"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace host
}
// namespace tensor_operation
}
// namespace ck
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment