Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7f09b8a0
Commit
7f09b8a0
authored
Jun 06, 2023
by
rocking
Browse files
Support f16 and bf16
parent
acd980fc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
183 additions
and
61 deletions
+183
-61
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
+6
-1
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
..._operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
+171
-54
No files found.
example/49_maxpool2d_bwd/maxpool2d_bwd_common.hpp
View file @
7f09b8a0
...
...
@@ -174,6 +174,12 @@ bool maxpool_bwd_test(bool do_verification,
"not support this problem"
);
}
size_t
pool_bwd_workspace_sz
=
pool_bwd
.
GetWorkSpaceSize
(
pool_bwd_argument_ptr
.
get
());
DeviceMem
pool_bwd_workspace_device_buf
(
pool_bwd_workspace_sz
);
pool_bwd_workspace_device_buf
.
SetZero
();
pool_bwd
.
SetWorkSpacePointer
(
pool_bwd_argument_ptr
.
get
(),
pool_bwd_workspace_device_buf
.
GetDeviceBuffer
());
float
ave_time_bwd
=
pool_bwd_invoker_ptr
->
Run
(
pool_bwd_argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
...
...
@@ -204,7 +210,6 @@ bool maxpool_bwd_test(bool do_verification,
window_strides
,
input_left_pads
,
input_right_pads
);
ref_pooling_fwd_invoker
.
Run
(
ref_pooling_fwd_argument
);
using
ReferencePoolingBwdInstance
=
ck
::
tensor_operation
::
host
::
...
...
example/49_maxpool2d_bwd/maxpool2d_bwd_fp32.cpp
View file @
7f09b8a0
...
...
@@ -9,8 +9,8 @@
#include "maxpool2d_bwd_common.hpp"
using
InDataType
=
floa
t
;
using
OutDataType
=
floa
t
;
using
InDataType
=
ck
::
half_
t
;
using
OutDataType
=
ck
::
half_
t
;
using
IndexDataType
=
int32_t
;
using
ComputeDataType
=
float
;
using
DInDataType
=
float
;
...
...
@@ -29,12 +29,12 @@ int main()
// Pool shape
ck
::
index_t
N
=
1
;
ck
::
index_t
C
=
1
;
ck
::
index_t
Y
=
2
;
ck
::
index_t
X
=
2
;
ck
::
index_t
Y
=
3
;
ck
::
index_t
X
=
3
;
ck
::
index_t
Hi
=
31
;
ck
::
index_t
Wi
=
31
;
ck
::
index_t
window_stride_h
=
2
;
ck
::
index_t
window_stride_w
=
2
;
ck
::
index_t
window_stride_h
=
1
;
ck
::
index_t
window_stride_w
=
1
;
ck
::
index_t
in_left_pad_h
=
0
;
ck
::
index_t
in_left_pad_w
=
0
;
ck
::
index_t
in_right_pad_h
=
1
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
View file @
7f09b8a0
...
...
@@ -11,6 +11,7 @@
#include "ck/tensor_operation/gpu/device/reduction_operator_mapping.hpp"
#include "ck/tensor_operation/gpu/device/device_index_pool_bwd.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
...
...
@@ -23,22 +24,24 @@ namespace device {
template
<
typename
DOutDataType
,
typename
IndexDataType
,
typename
DInDataType
,
ck
::
index_t
InVectorSize
>
ck
::
index_t
In
Out
VectorSize
>
struct
DeviceIndexPoolBwdImpl
:
public
DeviceIndexPoolBwd
<
DOutDataType
,
IndexDataType
,
DInDataType
>
{
static_assert
(
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
,
"Data type is not supported!"
);
using
DInDataType_AutomicAddPreCast
=
conditional_t
<
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
,
DInDataType
,
float
>
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
UnaryConvert
=
ck
::
tensor_operation
::
element_wise
::
UnaryConvert
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
template
<
typename
Desc_M
>
static
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
,
index_t
gridSize
,
index_t
blockSize
)
static
auto
PadDescriptor_M_1d
(
Desc_M
desc_m
,
index_t
loop_step
)
{
constexpr
auto
I0
=
Number
<
0
>
{};
const
auto
m
=
desc_m
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
blockSize
*
InVectorSize
;
const
auto
pad
=
math
::
integer_least_multiple
(
m
,
loop_step
)
-
m
;
const
auto
m
=
desc_m
.
GetLength
(
I0
);
const
auto
pad
=
math
::
integer_least_multiple
(
m
,
loop_step
)
-
m
;
const
auto
desc_m_pad
=
transform_tensor_descriptor
(
desc_m
,
make_tuple
(
make_right_pad_transform
(
m
,
pad
)),
...
...
@@ -47,29 +50,38 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
return
desc_m_pad
;
}
static
auto
MakeDescriptor_M
(
index_t
length
,
index_t
gridSize
,
index_t
blockSize
)
static
auto
MakeDescriptor_M
(
index_t
length
,
index_t
loop_step
)
{
const
auto
desc_m
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
length
));
return
PadDescriptor_M_1d
(
desc_m
,
gridSize
,
blockSize
);
return
PadDescriptor_M_1d
(
desc_m
,
loop_step
);
}
using
OutGrid1dDesc
=
decltype
(
MakeDescriptor_M
(
1
,
1
,
1
));
using
In
OutGrid1dDesc
=
decltype
(
MakeDescriptor_M
(
1
,
1
));
using
GridwisePutElementSet
=
GridwisePutElement_1D
<
OutGrid1dDesc
,
using
GridwisePutElementSet
=
GridwisePutElement_1D
<
In
OutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
InVectorSize
>
;
In
Out
VectorSize
>
;
using
GridwisePutElementAtomicAdd
=
GridwisePutElement_1D
<
OutGrid1dDesc
,
using
GridwisePutElementAtomicAdd
=
GridwisePutElement_1D
<
In
OutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType
,
DInDataType
_AutomicAddPreCast
,
PassThrough
,
InMemoryDataOperationEnum
::
AtomicAdd
,
InVectorSize
>
;
InOutVectorSize
>
;
using
GridwiseCasting
=
GridwiseElementwise_1D
<
Tuple
<
InOutGrid1dDesc
>
,
Tuple
<
InOutGrid1dDesc
>
,
Tuple
<
const
DInDataType_AutomicAddPreCast
*>
,
Tuple
<
DInDataType
*>
,
UnaryConvert
,
InOutVectorSize
,
Sequence
<
InOutVectorSize
>
,
Sequence
<
InOutVectorSize
>>
;
struct
Argument
:
public
BaseArgument
{
...
...
@@ -77,6 +89,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
const
IndexDataType
*
p_indices
,
DInDataType
*
p_din
,
index_t
dout_length
,
index_t
din_length
,
const
std
::
vector
<
ck
::
index_t
>&
window_lengths
,
const
std
::
vector
<
ck
::
index_t
>&
window_strides
)
:
p_dout_
{
p_dout
},
...
...
@@ -86,7 +99,9 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
gridSize_
{
104
},
// FIXME - Calculate the grid size by number of CU in the future
windowOverlap_
{
false
}
{
dout_grid_desc_
=
MakeDescriptor_M
(
dout_length
,
gridSize_
,
blockSize_
);
index_t
loop_step
=
gridSize_
*
blockSize_
*
InOutVectorSize
;
din_grid_desc_
=
MakeDescriptor_M
(
din_length
,
loop_step
);
dout_grid_desc_
=
MakeDescriptor_M
(
dout_length
,
loop_step
);
for
(
size_t
i
=
0
;
i
<
window_lengths
.
size
();
++
i
)
{
...
...
@@ -100,45 +115,126 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
index_t
blockSize_
;
index_t
gridSize_
;
bool
windowOverlap_
;
OutGrid1dDesc
dout_grid_desc_
;
InOutGrid1dDesc
din_grid_desc_
;
InOutGrid1dDesc
dout_grid_desc_
;
};
struct
Invoker
:
public
BaseInvoker
{
constexpr
auto
KernelSelector
(
bool
windowOverlap
)
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{}
)
{
if
(
windowOverlap
)
return
kernel_put_element_1d
<
GridwisePutElementAtomicAdd
,
OutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
>
;
if
constexpr
(
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
)
{
if
(
arg
.
windowOverlap_
)
{
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementAtomicAdd
,
InOutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
>
;
return
launch_and_time_kernel
(
stream_config
,
put_kernel
,
dim3
(
arg
.
gridSize_
),
dim3
(
arg
.
blockSize_
),
0
,
arg
.
dout_grid_desc_
,
arg
.
p_dout_
,
arg
.
p_indices_
,
arg
.
p_din_
,
PassThrough
{});
}
else
{
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementSet
,
InOutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
>
;
return
launch_and_time_kernel
(
stream_config
,
put_kernel
,
dim3
(
arg
.
gridSize_
),
dim3
(
arg
.
blockSize_
),
0
,
arg
.
dout_grid_desc_
,
arg
.
p_dout_
,
arg
.
p_indices_
,
arg
.
p_din_
,
PassThrough
{});
}
}
else
return
kernel_put_element_1d
<
GridwisePutElementSet
,
OutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
>
;
}
{
if
(
arg
.
windowOverlap_
)
{
if
(
arg
.
p_workspace_
==
nullptr
)
throw
std
::
runtime_error
(
"wrong! WorkSpace pointer has not been set"
);
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
kernel
=
KernelSelector
(
arg
.
windowOverlap_
);
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
arg
.
gridSize_
),
dim3
(
arg
.
blockSize_
),
0
,
arg
.
dout_grid_desc_
,
arg
.
p_dout_
,
arg
.
p_indices_
,
arg
.
p_din_
,
PassThrough
{});
return
elapsed_time
;
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementAtomicAdd
,
InOutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType_AutomicAddPreCast
,
PassThrough
>
;
const
auto
cast_kernel
=
kernel_elementwise_1d
<
GridwiseCasting
,
Tuple
<
InOutGrid1dDesc
>
,
Tuple
<
InOutGrid1dDesc
>
,
Tuple
<
const
DInDataType_AutomicAddPreCast
*>
,
Tuple
<
DInDataType
*>
,
UnaryConvert
>
;
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
put_kernel
,
dim3
(
arg
.
gridSize_
),
dim3
(
arg
.
blockSize_
),
0
,
arg
.
dout_grid_desc_
,
arg
.
p_dout_
,
arg
.
p_indices_
,
static_cast
<
DInDataType_AutomicAddPreCast
*>
(
arg
.
p_workspace_
),
PassThrough
{});
elapsed_time
+=
launch_and_time_kernel
(
stream_config
,
cast_kernel
,
dim3
(
arg
.
gridSize_
),
dim3
(
arg
.
blockSize_
),
0
,
ck
::
make_tuple
(
arg
.
din_grid_desc_
),
ck
::
make_tuple
(
arg
.
din_grid_desc_
),
static_cast
<
DInDataType_AutomicAddPreCast
*>
(
arg
.
p_workspace_
),
arg
.
p_din_
,
UnaryConvert
{});
return
elapsed_time
;
}
else
{
const
auto
put_kernel
=
kernel_put_element_1d
<
GridwisePutElementSet
,
InOutGrid1dDesc
,
DOutDataType
,
IndexDataType
,
DInDataType
,
PassThrough
>
;
return
launch_and_time_kernel
(
stream_config
,
put_kernel
,
dim3
(
arg
.
gridSize_
),
dim3
(
arg
.
blockSize_
),
0
,
arg
.
dout_grid_desc_
,
arg
.
p_dout_
,
arg
.
p_indices_
,
arg
.
p_din_
,
PassThrough
{});
}
}
}
float
Run
(
const
BaseArgument
*
p_arg
,
...
...
@@ -148,11 +244,31 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
}
};
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
pArg
)
const
override
{
const
Argument
*
pArg_
=
dynamic_cast
<
const
Argument
*>
(
pArg
);
bool
needCast
=
pArg_
->
windowOverlap_
&&
!
(
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
);
if
(
!
needCast
)
return
0
;
else
{
index_t
din_length
=
pArg_
->
din_grid_desc_
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I0
];
return
din_length
*
sizeof
(
DInDataType_AutomicAddPreCast
);
}
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
// TODO
ignore
=
pArg
;
index_t
din_length
=
pArg
->
din_grid_desc_
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I0
];
index_t
dout_length
=
pArg
->
dout_grid_desc_
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I0
];
if
(
din_length
%
InOutVectorSize
!=
0
||
dout_length
%
InOutVectorSize
!=
0
)
{
return
false
;
}
return
true
;
}
...
...
@@ -161,7 +277,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
const
void
*
p_indices
,
void
*
p_din
,
index_t
dout_length
,
index_t
,
index_t
din_length
,
std
::
vector
<
ck
::
index_t
>
window_lengths
,
std
::
vector
<
ck
::
index_t
>
window_strides
)
override
{
...
...
@@ -169,6 +285,7 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
static_cast
<
const
IndexDataType
*>
(
p_indices
),
static_cast
<
DInDataType
*>
(
p_din
),
dout_length
,
din_length
,
window_lengths
,
window_strides
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment