Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
bc7cc7c0
"...composable_kernel_rocm.git" did not exist on "a1c07e8d913cd03011f4ea3d45033ab4e765e9f1"
Commit
bc7cc7c0
authored
May 24, 2022
by
wangshaojie6
Browse files
add GetWorkSpaceSize to base arg and make an example on convnd_bwd_weight
parent
ba58a93f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
110 additions
and
9 deletions
+110
-9
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
+47
-9
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+2
-0
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
...e_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
+61
-0
No files found.
example/20_convnd_bwd_weight_xdl/convnd_bwd_weight_xdl.cpp
View file @
bc7cc7c0
...
...
@@ -257,11 +257,11 @@ int main(int argc, char* argv[])
case
0
:
break
;
case
1
:
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_2
<
OutDataType
>
{
-
2
,
2
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
Wei
DataType
>
{
-
2
,
2
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_2
<
In
DataType
>
{
-
2
,
2
});
break
;
default:
out_n_k_ho_wo
.
GenerateTensorValue
(
GeneratorTensor_1
<
OutDataType
>
{
1
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_1
<
Wei
DataType
>
{
1
});
in_n_c_hi_wi
.
GenerateTensorValue
(
GeneratorTensor_1
<
In
DataType
>
{
1
});
}
DeviceMem
in_device_buf
(
sizeof
(
InDataType
)
*
in_n_c_hi_wi
.
mDesc
.
GetElementSpace
());
...
...
@@ -296,15 +296,53 @@ int main(int argc, char* argv[])
OutElementOp
{},
split_k
);
if
(
!
conv
->
IsSupportedArgument
(
argument
.
get
()))
// alloc work space
size_t
bwd_weight_workspace_size
=
conv
->
GetWorkSpaceSize
(
argument
.
get
());
float
ave_time
=
0.
f
;
if
(
bwd_weight_workspace_size
>
0
)
{
std
::
cout
<<
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<<
std
::
endl
;
return
1
;
}
DeviceMem
wei_work_space_device_buf
(
bwd_weight_workspace_size
);
wei_work_space_device_buf
.
SetZero
();
argument
=
conv
->
MakeArgumentPointer
(
static_cast
<
InDataType
*>
(
in_device_buf
.
GetDeviceBuffer
()),
static_cast
<
AccDataType
*>
(
wei_work_space_device_buf
.
GetDeviceBuffer
()),
static_cast
<
OutDataType
*>
(
out_device_buf
.
GetDeviceBuffer
()),
params
.
N_
,
params
.
K_
,
params
.
C_
,
params
.
input_spatial_lengths_
,
params
.
filter_spatial_lengths_
,
output_spatial_lengths
,
params
.
conv_filter_strides_
,
params
.
conv_filter_dilations_
,
params
.
input_left_pads_
,
params
.
input_right_pads_
,
InElementOp
{},
WeiElementOp
{},
OutElementOp
{},
split_k
);
if
(
!
conv
->
IsSupportedArgument
(
argument
.
get
()))
{
std
::
cout
<<
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<<
std
::
endl
;
return
1
;
}
float
ave_time
=
invoker
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
ave_time
=
invoker
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
}
else
{
if
(
!
conv
->
IsSupportedArgument
(
argument
.
get
()))
{
std
::
cout
<<
"wrong! device_conv with the specified compilation parameters does "
"not support this Conv problem"
<<
std
::
endl
;
return
1
;
}
ave_time
=
invoker
->
Run
(
argument
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
}
std
::
size_t
flop
=
ck
::
utils
::
conv
::
get_flops
(
params
.
N_
,
params
.
C_
,
params
.
K_
,
params
.
filter_spatial_lengths_
,
output_spatial_lengths
);
...
...
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
bc7cc7c0
...
...
@@ -40,6 +40,8 @@ struct BaseOperator
virtual
bool
IsSupportedArgument
(
const
BaseArgument
*
)
{
return
false
;
}
virtual
std
::
string
GetTypeString
()
const
{
return
""
;
}
virtual
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
)
const
{
return
0
;
}
virtual
~
BaseOperator
()
{}
};
...
...
include/ck/tensor_operation/gpu/device/device_convnd_backward_weight_xdl_c_shuffle_nhwc_kyxc_nhwk.hpp
View file @
bc7cc7c0
...
...
@@ -1175,6 +1175,67 @@ struct DeviceConvndBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
return
str
.
str
();
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
1
,
bool
>
::
type
=
false
>
static
size_t
GetWorkSpaceSize
(
const
Argument
&
arg
)
{
size_t
WorkSpaceSize
=
0
;
if
constexpr
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
)
{
WorkSpaceSize
=
arg
.
Conv_K_
*
arg
.
Conv_C_
*
arg
.
filter_spatial_lengths_
[
0
]
*
sizeof
(
float
);
}
else
{
WorkSpaceSize
=
arg
.
Conv_K_
*
0
;
}
return
WorkSpaceSize
;
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
2
,
bool
>
::
type
=
false
>
static
size_t
GetWorkSpaceSize
(
const
Argument
&
arg
)
{
size_t
WorkSpaceSize
=
0
;
if
constexpr
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
)
{
WorkSpaceSize
=
arg
.
Conv_K_
*
arg
.
Conv_C_
*
arg
.
filter_spatial_lengths_
[
0
]
*
arg
.
filter_spatial_lengths_
[
1
]
*
sizeof
(
float
);
}
else
{
WorkSpaceSize
=
arg
.
Conv_K_
*
0
;
}
return
WorkSpaceSize
;
}
template
<
ck
::
index_t
NDim
,
typename
ck
::
enable_if
<
NDim
==
3
,
bool
>
::
type
=
false
>
static
size_t
GetWorkSpaceSize
(
const
Argument
&
arg
)
{
size_t
WorkSpaceSize
=
0
;
if
(
arg
.
k_batch_
>
1
)
{
if
constexpr
(
std
::
is_same
<
InDataType
,
ck
::
bhalf_t
>::
value
)
{
WorkSpaceSize
=
arg
.
Conv_K_
*
arg
.
Conv_C_
*
arg
.
filter_spatial_lengths_
[
0
]
*
arg
.
filter_spatial_lengths_
[
1
]
*
arg
.
filter_spatial_lengths_
[
2
]
*
sizeof
(
float
);
}
else
{
WorkSpaceSize
=
arg
.
Conv_K_
*
0
;
}
}
else
{
WorkSpaceSize
=
arg
.
Conv_K_
*
0
;
}
return
WorkSpaceSize
;
}
size_t
GetWorkSpaceSize
(
const
BaseArgument
*
p_arg
)
const
override
final
{
return
GetWorkSpaceSize
<
NumDimSpatial
>
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
));
}
};
}
// namespace device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment