Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
38962b98
"...composable_kernel.git" did not exist on "29087570093f38075ed25d48b3f5c4d2885e47fa"
Commit
38962b98
authored
Jun 15, 2023
by
rocking
Browse files
Calculate gridsize according to the number of CU
parent
3550cefe
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
22 deletions
+21
-22
include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
..._operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
+21
-22
No files found.
include/ck/tensor_operation/gpu/device/impl/device_index_pool_bwd_impl.hpp
View file @
38962b98
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/stream_utility.hpp"
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
...
@@ -94,15 +95,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -94,15 +95,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
:
p_dout_
{
p_dout
},
:
p_dout_
{
p_dout
},
p_indices_
{
p_indices
},
p_indices_
{
p_indices
},
p_din_
{
p_din
},
p_din_
{
p_din
},
dout_length_raw_
{
dout_length
},
din_length_raw_
{
din_length
},
din_length_raw_
{
din_length
},
blockSize_
{
256
},
blockSize_
{
256
},
gridSize_
{
104
},
// FIXME - Calculate the grid size by number of CU in the future
windowOverlap_
{
false
}
windowOverlap_
{
false
}
{
{
index_t
loop_step
=
gridSize_
*
blockSize_
*
InOutVectorSize
;
din_grid_desc_
=
MakeDescriptor_M
(
din_length
,
loop_step
);
dout_grid_desc_
=
MakeDescriptor_M
(
dout_length
,
loop_step
);
for
(
size_t
i
=
0
;
i
<
window_lengths
.
size
();
++
i
)
for
(
size_t
i
=
0
;
i
<
window_lengths
.
size
();
++
i
)
{
{
windowOverlap_
|=
window_lengths
.
at
(
i
)
>
window_strides
.
at
(
i
);
windowOverlap_
|=
window_lengths
.
at
(
i
)
>
window_strides
.
at
(
i
);
...
@@ -112,18 +109,21 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -112,18 +109,21 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
const
DOutDataType
*
p_dout_
;
const
DOutDataType
*
p_dout_
;
const
IndexDataType
*
p_indices_
;
const
IndexDataType
*
p_indices_
;
DInDataType
*
p_din_
;
DInDataType
*
p_din_
;
index_t
dout_length_raw_
;
index_t
din_length_raw_
;
index_t
din_length_raw_
;
index_t
blockSize_
;
index_t
blockSize_
;
index_t
gridSize_
;
bool
windowOverlap_
;
bool
windowOverlap_
;
InOutGrid1dDesc
din_grid_desc_
;
InOutGrid1dDesc
dout_grid_desc_
;
};
};
struct
Invoker
:
public
BaseInvoker
struct
Invoker
:
public
BaseInvoker
{
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
{
index_t
gridSize
=
getAvailableComputeUnitCount
(
stream_config
);
index_t
loop_step
=
gridSize
*
arg
.
blockSize_
*
InOutVectorSize
;
InOutGrid1dDesc
din_grid_desc
=
MakeDescriptor_M
(
arg
.
din_length_raw_
,
loop_step
);
InOutGrid1dDesc
dout_grid_desc
=
MakeDescriptor_M
(
arg
.
dout_length_raw_
,
loop_step
);
if
constexpr
(
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
)
if
constexpr
(
is_same_v
<
DInDataType
,
float
>
||
is_same_v
<
DInDataType
,
double
>
)
{
{
hip_check_error
(
hipMemsetAsync
(
arg
.
p_din_
,
hip_check_error
(
hipMemsetAsync
(
arg
.
p_din_
,
...
@@ -142,10 +142,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -142,10 +142,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
put_kernel
,
put_kernel
,
dim3
(
arg
.
gridSize
_
),
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
dim3
(
arg
.
blockSize_
),
0
,
0
,
arg
.
dout_grid_desc
_
,
dout_grid_desc
,
arg
.
p_dout_
,
arg
.
p_dout_
,
arg
.
p_indices_
,
arg
.
p_indices_
,
arg
.
p_din_
,
arg
.
p_din_
,
...
@@ -162,10 +162,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -162,10 +162,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
put_kernel
,
put_kernel
,
dim3
(
arg
.
gridSize
_
),
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
dim3
(
arg
.
blockSize_
),
0
,
0
,
arg
.
dout_grid_desc
_
,
dout_grid_desc
,
arg
.
p_dout_
,
arg
.
p_dout_
,
arg
.
p_indices_
,
arg
.
p_indices_
,
arg
.
p_din_
,
arg
.
p_din_
,
...
@@ -203,10 +203,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -203,10 +203,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
float
elapsed_time
=
launch_and_time_kernel
(
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
stream_config
,
put_kernel
,
put_kernel
,
dim3
(
arg
.
gridSize
_
),
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
dim3
(
arg
.
blockSize_
),
0
,
0
,
arg
.
dout_grid_desc
_
,
dout_grid_desc
,
arg
.
p_dout_
,
arg
.
p_dout_
,
arg
.
p_indices_
,
arg
.
p_indices_
,
static_cast
<
DInDataType_AutomicAddPreCast
*>
(
arg
.
p_workspace_
),
static_cast
<
DInDataType_AutomicAddPreCast
*>
(
arg
.
p_workspace_
),
...
@@ -215,11 +215,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -215,11 +215,11 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
elapsed_time
+=
launch_and_time_kernel
(
elapsed_time
+=
launch_and_time_kernel
(
stream_config
,
stream_config
,
cast_kernel
,
cast_kernel
,
dim3
(
arg
.
gridSize
_
),
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
dim3
(
arg
.
blockSize_
),
0
,
0
,
ck
::
make_tuple
(
arg
.
din_grid_desc
_
),
ck
::
make_tuple
(
din_grid_desc
),
ck
::
make_tuple
(
arg
.
din_grid_desc
_
),
ck
::
make_tuple
(
din_grid_desc
),
static_cast
<
DInDataType_AutomicAddPreCast
*>
(
arg
.
p_workspace_
),
static_cast
<
DInDataType_AutomicAddPreCast
*>
(
arg
.
p_workspace_
),
arg
.
p_din_
,
arg
.
p_din_
,
UnaryConvert
{});
UnaryConvert
{});
...
@@ -242,10 +242,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -242,10 +242,10 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
return
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
put_kernel
,
put_kernel
,
dim3
(
arg
.
gridSize
_
),
dim3
(
gridSize
),
dim3
(
arg
.
blockSize_
),
dim3
(
arg
.
blockSize_
),
0
,
0
,
arg
.
dout_grid_desc
_
,
dout_grid_desc
,
arg
.
p_dout_
,
arg
.
p_dout_
,
arg
.
p_indices_
,
arg
.
p_indices_
,
arg
.
p_din_
,
arg
.
p_din_
,
...
@@ -277,9 +277,8 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
...
@@ -277,9 +277,8 @@ struct DeviceIndexPoolBwdImpl : public DeviceIndexPoolBwd<DOutDataType, IndexDat
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
{
{
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
const
Argument
*
pArg
=
dynamic_cast
<
const
Argument
*>
(
p_arg
);
index_t
din_length
=
pArg
->
din_grid_desc_
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I0
];
if
(
pArg
->
din_length_raw_
%
InOutVectorSize
!=
0
||
index_t
dout_length
=
pArg
->
dout_grid_desc_
.
GetTransforms
()[
I0
].
GetUpperLengths
()[
I0
];
pArg
->
dout_length_raw_
%
InOutVectorSize
!=
0
)
if
(
din_length
%
InOutVectorSize
!=
0
||
dout_length
%
InOutVectorSize
!=
0
)
{
{
return
false
;
return
false
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment