Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f919809d
Commit
f919809d
authored
Apr 26, 2022
by
rocking
Browse files
Move threadPerBlock to argument
parent
a41f5481
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
19 deletions
+24
-19
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
+4
-4
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
...tensor_operation/gpu/device/device_binary_elementwise.hpp
+2
-1
include/ck/tensor_operation/gpu/device/device_binary_elementwise_2d.hpp
...sor_operation/gpu/device/device_binary_elementwise_2d.hpp
+18
-14
No files found.
example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp
View file @
f919809d
...
@@ -179,7 +179,6 @@ using DeviceElementwiseSubExpInstance =
...
@@ -179,7 +179,6 @@ using DeviceElementwiseSubExpInstance =
CDataType
,
CDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
,
SubExp
,
SubExp
,
256
,
8
>
;
8
>
;
using
DeviceElementwiseDivInstance
=
using
DeviceElementwiseDivInstance
=
...
@@ -188,7 +187,6 @@ using DeviceElementwiseDivInstance =
...
@@ -188,7 +187,6 @@ using DeviceElementwiseDivInstance =
CDataType
,
CDataType
,
EltwiseComputeDataType
,
EltwiseComputeDataType
,
Div
,
Div
,
256
,
8
>
;
8
>
;
using
HostGemmInstance
=
ck
::
tensor_operation
::
host
::
using
HostGemmInstance
=
ck
::
tensor_operation
::
host
::
...
@@ -416,7 +414,8 @@ int main(int argc, char* argv[])
...
@@ -416,7 +414,8 @@ int main(int argc, char* argv[])
{
StrideC
,
1
},
{
StrideC
,
1
},
{
0
,
1
},
{
0
,
1
},
{
StrideC
,
1
},
{
StrideC
,
1
},
SubExp
{});
SubExp
{},
256
);
if
(
!
broadcastSubExp
.
IsSupportedArgument
(
broadcastSubExp_argument_ptr
.
get
()))
if
(
!
broadcastSubExp
.
IsSupportedArgument
(
broadcastSubExp_argument_ptr
.
get
()))
{
{
...
@@ -466,7 +465,8 @@ int main(int argc, char* argv[])
...
@@ -466,7 +465,8 @@ int main(int argc, char* argv[])
{
StrideC
,
1
},
{
StrideC
,
1
},
{
0
,
1
},
{
0
,
1
},
{
StrideC
,
1
},
{
StrideC
,
1
},
Div
{});
Div
{},
256
);
if
(
!
broadcastDiv
.
IsSupportedArgument
(
broadcastDiv_argument_ptr
.
get
()))
if
(
!
broadcastDiv
.
IsSupportedArgument
(
broadcastDiv_argument_ptr
.
get
()))
{
{
...
...
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
View file @
f919809d
...
@@ -19,7 +19,8 @@ struct DeviceBinaryElementwise : public BaseOperator
...
@@ -19,7 +19,8 @@ struct DeviceBinaryElementwise : public BaseOperator
const
std
::
vector
<
int
>&
stride_a
,
const
std
::
vector
<
int
>&
stride_a
,
const
std
::
vector
<
int
>&
shape_b
,
const
std
::
vector
<
int
>&
shape_b
,
const
std
::
vector
<
int
>&
stride_b
,
const
std
::
vector
<
int
>&
stride_b
,
ElementwiseFunctor
functor
)
=
0
;
ElementwiseFunctor
functor
,
index_t
threadPerBlock
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
};
...
...
include/ck/tensor_operation/gpu/device/device_binary_elementwise_2d.hpp
View file @
f919809d
...
@@ -15,7 +15,6 @@ template <typename ADataType,
...
@@ -15,7 +15,6 @@ template <typename ADataType,
typename
CDataType
,
typename
CDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
ElementwiseFunctor
,
typename
ElementwiseFunctor
,
index_t
ThreadPerBlock
,
index_t
ScalarPerVector
>
index_t
ScalarPerVector
>
struct
DeviceBinaryElementwise_2D
:
public
DeviceBinaryElementwise
<
ElementwiseFunctor
>
struct
DeviceBinaryElementwise_2D
:
public
DeviceBinaryElementwise
<
ElementwiseFunctor
>
{
{
...
@@ -23,7 +22,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
...
@@ -23,7 +22,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
int
>&
shape
,
static
auto
MakeDescriptor_M0
(
const
std
::
vector
<
int
>&
shape
,
const
std
::
vector
<
int
>&
stride
,
const
std
::
vector
<
int
>&
stride
,
index_t
gridSize
)
index_t
gridSize
,
index_t
threadPerBlock
)
{
{
const
int
m
=
shape
[
0
];
const
int
m
=
shape
[
0
];
const
int
n
=
shape
[
1
];
const
int
n
=
shape
[
1
];
...
@@ -41,7 +41,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
...
@@ -41,7 +41,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
// pad
// pad
const
auto
m0
=
desc_m0
.
GetLength
(
I0
);
const
auto
m0
=
desc_m0
.
GetLength
(
I0
);
const
index_t
loop_step
=
gridSize
*
T
hreadPerBlock
*
ScalarPerVector
;
const
index_t
loop_step
=
gridSize
*
t
hreadPerBlock
*
ScalarPerVector
;
const
auto
pad
=
math
::
integer_least_multiple
(
m0
,
loop_step
)
-
m0
;
const
auto
pad
=
math
::
integer_least_multiple
(
m0
,
loop_step
)
-
m0
;
const
auto
desc_m0_pad
=
const
auto
desc_m0_pad
=
transform_tensor_descriptor
(
desc_m0
,
transform_tensor_descriptor
(
desc_m0
,
...
@@ -51,7 +51,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
...
@@ -51,7 +51,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
return
desc_m0_pad
;
return
desc_m0_pad
;
}
}
using
GridDesc_M0
=
decltype
(
MakeDescriptor_M0
({
1
,
1
},
{
1
,
1
},
1
));
using
GridDesc_M0
=
decltype
(
MakeDescriptor_M0
({
1
,
1
},
{
1
,
1
},
1
,
1
));
using
GridwiseBinEltwise
=
GridwiseBinaryElementwise_1D
<
ADataType
,
using
GridwiseBinEltwise
=
GridwiseBinaryElementwise_1D
<
ADataType
,
BDataType
,
BDataType
,
CDataType
,
CDataType
,
...
@@ -69,16 +69,18 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
...
@@ -69,16 +69,18 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
const
std
::
vector
<
int
>&
stride_a
,
const
std
::
vector
<
int
>&
stride_a
,
const
std
::
vector
<
int
>&
stride_b
,
const
std
::
vector
<
int
>&
stride_b
,
const
std
::
vector
<
int
>&
stride_c
,
const
std
::
vector
<
int
>&
stride_c
,
ElementwiseFunctor
functor
)
ElementwiseFunctor
functor
,
index_t
threadPerBlock
)
:
p_a_
(
p_a
),
:
p_a_
(
p_a
),
p_b_
(
p_b
),
p_b_
(
p_b
),
p_c_
(
p_c
),
p_c_
(
p_c
),
functor_
(
functor
),
functor_
(
functor
),
threadPerBlock_
(
threadPerBlock
),
gridSize_
(
128
)
// FIXME - Calculate the grid size by number of CU in the future
gridSize_
(
128
)
// FIXME - Calculate the grid size by number of CU in the future
{
{
a_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_a
,
gridSize_
);
a_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_a
,
gridSize_
,
threadPerBlock_
);
b_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_b
,
gridSize_
);
b_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_b
,
gridSize_
,
threadPerBlock_
);
c_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_c
,
gridSize_
);
c_grid_desc_m0_
=
MakeDescriptor_M0
(
shape
,
stride_c
,
gridSize_
,
threadPerBlock_
);
}
}
const
ADataType
*
p_a_
;
const
ADataType
*
p_a_
;
...
@@ -88,6 +90,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
...
@@ -88,6 +90,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
GridDesc_M0
b_grid_desc_m0_
;
GridDesc_M0
b_grid_desc_m0_
;
GridDesc_M0
c_grid_desc_m0_
;
GridDesc_M0
c_grid_desc_m0_
;
ElementwiseFunctor
functor_
;
ElementwiseFunctor
functor_
;
index_t
threadPerBlock_
;
index_t
gridSize_
;
index_t
gridSize_
;
};
};
...
@@ -102,12 +105,12 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
...
@@ -102,12 +105,12 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
CDataType
,
CDataType
,
GridDesc_M0
,
GridDesc_M0
,
ElementwiseFunctor
>
;
ElementwiseFunctor
>
;
float
avgTime
=
0
;
float
avgTime
=
0
;
if
(
nrepeat
==
0
)
if
(
nrepeat
==
0
)
{
{
launch_kernel
(
kernel
,
launch_kernel
(
kernel
,
dim3
(
arg
.
gridSize_
),
dim3
(
arg
.
gridSize_
),
dim3
(
T
hreadPerBlock
),
dim3
(
arg
.
t
hreadPerBlock
_
),
0
,
0
,
arg
.
p_a_
,
arg
.
p_a_
,
arg
.
p_b_
,
arg
.
p_b_
,
...
@@ -122,7 +125,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
...
@@ -122,7 +125,7 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
avgTime
=
launch_and_time_kernel
(
kernel
,
avgTime
=
launch_and_time_kernel
(
kernel
,
nrepeat
,
nrepeat
,
dim3
(
arg
.
gridSize_
),
dim3
(
arg
.
gridSize_
),
dim3
(
T
hreadPerBlock
),
dim3
(
arg
.
t
hreadPerBlock
_
),
0
,
0
,
arg
.
p_a_
,
arg
.
p_a_
,
arg
.
p_b_
,
arg
.
p_b_
,
...
@@ -164,7 +167,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
...
@@ -164,7 +167,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
const
std
::
vector
<
int
>&
stride_a
,
const
std
::
vector
<
int
>&
stride_a
,
const
std
::
vector
<
int
>&
stride_b
,
const
std
::
vector
<
int
>&
stride_b
,
const
std
::
vector
<
int
>&
stride_c
,
const
std
::
vector
<
int
>&
stride_c
,
ElementwiseFunctor
functor
)
override
ElementwiseFunctor
functor
,
index_t
threadPerBlock
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
@@ -173,7 +177,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
...
@@ -173,7 +177,8 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
stride_a
,
stride_a
,
stride_b
,
stride_b
,
stride_c
,
stride_c
,
functor
);
functor
,
threadPerBlock
);
}
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
...
@@ -188,7 +193,6 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
...
@@ -188,7 +193,6 @@ struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFu
// clang-format off
// clang-format off
str
<<
"DeviceBinaryElementwise_2D"
str
<<
"DeviceBinaryElementwise_2D"
<<
"<"
<<
"<"
<<
"ThreadPerBlock = "
<<
ThreadPerBlock
<<
"ScalarPerVector = "
<<
ScalarPerVector
<<
"ScalarPerVector = "
<<
ScalarPerVector
<<
">"
;
<<
">"
;
// clang-format on
// clang-format on
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment