Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
c4d610be
Commit
c4d610be
authored
May 18, 2022
by
rocking
Browse files
Move thread per block to the parameter of constructor
parent
83f75313
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
21 additions
and
16 deletions
+21
-16
example/19_binary_elementwise/broadcast_add_2d.cpp
example/19_binary_elementwise/broadcast_add_2d.cpp
+1
-2
example/19_binary_elementwise/elementwise_add_1d.cpp
example/19_binary_elementwise/elementwise_add_1d.cpp
+1
-2
example/19_binary_elementwise/elementwise_add_4d.cpp
example/19_binary_elementwise/elementwise_add_4d.cpp
+1
-2
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
...tensor_operation/gpu/device/device_binary_elementwise.hpp
+18
-10
No files found.
example/19_binary_elementwise/broadcast_add_2d.cpp
View file @
c4d610be
...
...
@@ -101,8 +101,7 @@ int main()
{
Stride
,
1
},
{
0
,
1
},
// broadcast in first dimension
{
Stride
,
1
},
Add
{},
256
);
Add
{});
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
...
...
example/19_binary_elementwise/elementwise_add_1d.cpp
View file @
c4d610be
...
...
@@ -80,8 +80,7 @@ int main()
{
1
},
{
1
},
{
1
},
Add
{},
256
);
Add
{});
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
...
...
example/19_binary_elementwise/elementwise_add_4d.cpp
View file @
c4d610be
...
...
@@ -82,8 +82,7 @@ int main()
ck
::
to_int_vector
(
a_m
.
mDesc
.
GetStrides
()),
ck
::
to_int_vector
(
b_m
.
mDesc
.
GetStrides
()),
ck
::
to_int_vector
(
c_m
.
mDesc
.
GetStrides
()),
Add
{},
256
);
Add
{});
if
(
!
broadcastAdd
.
IsSupportedArgument
(
argument
.
get
()))
{
...
...
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
View file @
c4d610be
...
...
@@ -19,6 +19,11 @@ template <typename ADataType,
index_t
ScalarPerVector
>
struct
DeviceBinaryElementwise
:
public
BaseOperator
{
DeviceBinaryElementwise
(
index_t
threadPerBlock
=
256
)
:
BaseOperator
(),
threadPerBlock_
(
threadPerBlock
)
{
}
static
constexpr
auto
I0
=
Number
<
0
>
{};
template
<
typename
Desc_M0
>
...
...
@@ -85,12 +90,11 @@ struct DeviceBinaryElementwise : public BaseOperator
p_b_
(
p_b
),
p_c_
(
p_c
),
functor_
(
functor
),
threadPerBlock_
(
threadPerBlock
),
gridSize_
(
120
)
// FIXME - Calculate the grid size by number of CU in the future
{
a_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_a
,
gridSize_
,
threadPerBlock
_
);
b_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_b
,
gridSize_
,
threadPerBlock
_
);
c_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_c
,
gridSize_
,
threadPerBlock
_
);
a_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_a
,
gridSize_
,
threadPerBlock
);
b_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_b
,
gridSize_
,
threadPerBlock
);
c_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_c
,
gridSize_
,
threadPerBlock
);
}
const
ADataType
*
p_a_
;
...
...
@@ -100,12 +104,13 @@ struct DeviceBinaryElementwise : public BaseOperator
GridDesc_M0
b_grid_desc_m0_
;
GridDesc_M0
c_grid_desc_m0_
;
ElementwiseFunctor
functor_
;
index_t
threadPerBlock_
;
index_t
gridSize_
;
};
struct
Invoker
:
public
BaseInvoker
{
Invoker
(
index_t
threadPerBlock
)
:
BaseInvoker
(),
threadPerBlock_
(
threadPerBlock
)
{}
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
kernel
=
kernel_elementwise_1d
<
GridwiseBinEltwise
,
...
...
@@ -118,7 +123,7 @@ struct DeviceBinaryElementwise : public BaseOperator
float
elapsed_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
arg
.
gridSize_
),
dim3
(
arg
.
threadPerBlock_
),
dim3
(
threadPerBlock_
),
0
,
arg
.
p_a_
,
arg
.
p_b_
,
...
...
@@ -136,6 +141,8 @@ struct DeviceBinaryElementwise : public BaseOperator
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
index_t
threadPerBlock_
;
};
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
@@ -161,8 +168,7 @@ struct DeviceBinaryElementwise : public BaseOperator
std
::
vector
<
int
>
stride_a
,
std
::
vector
<
int
>
stride_b
,
std
::
vector
<
int
>
stride_c
,
ElementwiseFunctor
functor
,
index_t
threadPerBlock
)
ElementwiseFunctor
functor
)
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
...
@@ -172,12 +178,12 @@ struct DeviceBinaryElementwise : public BaseOperator
stride_b
,
stride_c
,
functor
,
threadPerBlock
);
threadPerBlock
_
);
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
return
std
::
make_unique
<
Invoker
>
(
Invoker
{
threadPerBlock_
});
}
std
::
string
GetTypeString
()
const
override
...
...
@@ -193,6 +199,8 @@ struct DeviceBinaryElementwise : public BaseOperator
return
str
.
str
();
}
index_t
threadPerBlock_
;
};
}
// namespace device
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment