gaoqiong / composable_kernel

Commit b3767dbe, authored May 17, 2022 by myamlak

Merge remote-tracking branch 'origin/eltwise_op' into myamlak/cgemm

Parents: e00a943e, ecdfe960

Showing 10 changed files with 662 additions and 1 deletion (+662 -1)
example/19_binary_elementwise/CMakeLists.txt                                  +2   -0
example/19_binary_elementwise/broadcast_add_2d.cpp                            +137 -0
example/19_binary_elementwise/elementwise_add_1d.cpp                          +119 -0
example/20_cgemm/CMakeLists.txt                                               +0   -0
example/20_cgemm/cgemm_xdl_bf16.cpp                                           +0   -0
example/CMakeLists.txt                                                        +2   -1
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp          +229 -0
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp     +19  -0
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp       +150 -0
include/ck/utility/get_id.hpp                                                 +4   -0
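Taken together, these files add a generic binary elementwise device operation to composable_kernel: a device-level operator, a grid-level kernel, an Add functor, and two examples that drive them. The sketch below condenses the call sequence that both new examples follow (instantiate the device op, build an argument, check support, then run through an invoker). It assumes the same includes and template parameters as broadcast_add_2d.cpp; the helper name and the raw-pointer parameters are illustrative, not part of the commit.

#include "device_binary_elementwise.hpp"
#include "binary_element_wise_operation.hpp"
// ...plus the host-tensor and DeviceMem helpers used by the examples below.

// Hypothetical helper: C = A + broadcast(B) on already-allocated device buffers.
float run_broadcast_add(const void* p_a, const void* p_b, void* p_c,
                        int M, int N, int Stride, bool time_kernel)
{
    using AddOp     = ck::tensor_operation::binary_element_wise::Add;
    using DeviceAdd = ck::tensor_operation::device::
        DeviceBinaryElementwise<ck::half_t, ck::half_t, ck::half_t, float, AddOp, 2, 8>;

    auto op       = DeviceAdd{};
    auto argument = op.MakeArgumentPointer(p_a, p_b, p_c,
                                           {M, N},      // shape
                                           {Stride, 1}, // strides of A
                                           {0, 1},      // strides of B: broadcast over M
                                           {Stride, 1}, // strides of C
                                           AddOp{},
                                           256);        // threads per block

    if(!op.IsSupportedArgument(argument.get()))
        throw std::runtime_error("unsupported arguments");

    return op.MakeInvokerPointer()->Run(argument.get(), StreamConfig{nullptr, time_kernel});
}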
example/19_binary_elementwise/CMakeLists.txt (new file, mode 100644)

add_example_executable(example_broadcast_add_2d broadcast_add_2d.cpp)
add_example_executable(example_elementwise_add_1d elementwise_add_1d.cpp)
example/19_binary_elementwise/broadcast_add_2d.cpp (new file, mode 100644)

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include <math.h>

#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"

using F16 = ck::half_t;
using F32 = float;

using ABDataType             = F16;
using CDataType              = F16;
using EltwiseComputeDataType = F32;

using Add = ck::tensor_operation::binary_element_wise::Add;

using DeviceElementwiseAddInstance = ck::tensor_operation::device::
    DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 2, 8>;

template <typename HostTensorA,
          typename HostTensorB,
          typename HostTensorC,
          typename ComputeDataType,
          typename Functor,
          int broadcastDim>
void host_broadcast2D(
    HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, int N, Functor functor)
{
    for(int m = 0; m < M; ++m)
    {
        for(int n = 0; n < N; ++n)
        {
            ComputeDataType Amn = static_cast<ComputeDataType>(A(m, n));
            ComputeDataType Cmn = 0;
            if constexpr(broadcastDim == 0)
            {
                ComputeDataType Bn = static_cast<ComputeDataType>(B(n));
                functor(Cmn, Amn, Bn);
            }
            else
            {
                ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
                functor(Cmn, Amn, Bm);
            }
            C(m, n) = static_cast<ComputeDataType>(Cmn);
        }
    }
}

int main()
{
    bool do_verification = true;
    bool time_kernel     = false;

    ck::index_t M      = 1024;
    ck::index_t N      = 1024;
    ck::index_t Stride = 1024;

    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({len}),
                                    std::vector<std::size_t>({stride}));
    };

    auto f_host_tensor_descriptor2d = [](std::size_t row, std::size_t col, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                    std::vector<std::size_t>({stride, 1}));
    };

    Tensor<ABDataType> a_m_n(f_host_tensor_descriptor2d(M, N, Stride));
    Tensor<ABDataType> b_n(f_host_tensor_descriptor1d(N, 1));
    Tensor<CDataType> c_m_n(f_host_tensor_descriptor2d(M, N, Stride));

    a_m_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
    b_n.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});

    DeviceMem a_m_n_device_buf(sizeof(ABDataType) * a_m_n.mDesc.GetElementSpace());
    DeviceMem b_n_device_buf(sizeof(ABDataType) * b_n.mDesc.GetElementSpace());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * c_m_n.mDesc.GetElementSpace());

    a_m_n_device_buf.ToDevice(a_m_n.mData.data());
    b_n_device_buf.ToDevice(b_n.mData.data());

    auto broadcastAdd = DeviceElementwiseAddInstance{};
    auto argument     = broadcastAdd.MakeArgumentPointer(a_m_n_device_buf.GetDeviceBuffer(),
                                                         b_n_device_buf.GetDeviceBuffer(),
                                                         c_m_n_device_buf.GetDeviceBuffer(),
                                                         {M, N},
                                                         {Stride, 1},
                                                         {0, 1}, // broadcast in first dimension
                                                         {Stride, 1},
                                                         Add{},
                                                         256);

    if(!broadcastAdd.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error("The runtime parameters seems not supported by the "
                                 "DeviceBinaryElementwise_2D instance, exiting!");
    };

    auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
    float ave_time =
        broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});

    std::cout << "Perf: " << ave_time << " ms" << std::endl;

    bool pass = true;
    if(do_verification)
    {
        c_m_n_device_buf.FromDevice(c_m_n.mData.data());
        Tensor<CDataType> host_c_m_n(f_host_tensor_descriptor2d(M, N, Stride));

        host_broadcast2D<Tensor<ABDataType>,
                         Tensor<ABDataType>,
                         Tensor<CDataType>,
                         EltwiseComputeDataType,
                         Add,
                         0>(host_c_m_n, a_m_n, b_n, M, N, Add{});

        pass &= ck::utils::check_err(
            c_m_n.mData, host_c_m_n.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
    }

    return pass ? 0 : 1;
}
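The broadcast in this example is expressed entirely through B's stride argument: passing {0, 1} for an {M, N} shape gives B a zero stride along M, so every row re-reads the same N-element vector, which is exactly what host_broadcast2D with broadcastDim = 0 recomputes on the host. A minimal sketch of the equivalent offset arithmetic (hypothetical helper, shown only for illustration):

// Linear offset of element (m, n) under strides {s0, s1} (illustration only).
inline int linear_offset(int m, int n, int s0, int s1) { return m * s0 + n * s1; }

// A with strides {Stride, 1}: linear_offset(m, n, Stride, 1) == m * Stride + n
// B with strides {0, 1}:      linear_offset(m, n, 0, 1)      == n   -> broadcast along M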
example/19_binary_elementwise/elementwise_add_1d.cpp (new file, mode 100644)

#include <iostream>
#include <numeric>
#include <initializer_list>
#include <cstdlib>
#include <stdlib.h>
#include <half.hpp>
#include <math.h>

#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "device_tensor.hpp"
#include "binary_element_wise_operation.hpp"
#include "device_binary_elementwise.hpp"

using F16 = ck::half_t;
using F32 = float;

using ABDataType             = F16;
using CDataType              = F16;
using EltwiseComputeDataType = F32;

using Add = ck::tensor_operation::binary_element_wise::Add;

using DeviceElementwiseAddInstance = ck::tensor_operation::device::
    DeviceBinaryElementwise<F16, F16, CDataType, EltwiseComputeDataType, Add, 1, 8>;

template <typename HostTensorA,
          typename HostTensorB,
          typename HostTensorC,
          typename ComputeDataType,
          typename Functor,
          int broadcastDim>
void host_elementwise1D(
    HostTensorC& C, const HostTensorA& A, const HostTensorB& B, int M, Functor functor)
{
    for(int m = 0; m < M; ++m)
    {
        ComputeDataType Am = static_cast<ComputeDataType>(A(m));
        ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
        ComputeDataType Cm = 0;
        functor(Cm, Am, Bm);
        C(m) = static_cast<ComputeDataType>(Cm);
    }
}

int main()
{
    bool do_verification = true;
    bool time_kernel     = false;

    ck::index_t M = 1024;

    auto f_host_tensor_descriptor1d = [](std::size_t len, std::size_t stride) {
        return HostTensorDescriptor(std::vector<std::size_t>({len}),
                                    std::vector<std::size_t>({stride}));
    };

    Tensor<ABDataType> a_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<ABDataType> b_m(f_host_tensor_descriptor1d(M, 1));
    Tensor<ABDataType> c_m(f_host_tensor_descriptor1d(M, 1));

    a_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});
    b_m.GenerateTensorValue(GeneratorTensor_3<ABDataType>{0.0, 1.0});

    DeviceMem a_m_device_buf(sizeof(ABDataType) * a_m.mDesc.GetElementSpace());
    DeviceMem b_m_device_buf(sizeof(ABDataType) * b_m.mDesc.GetElementSpace());
    DeviceMem c_m_device_buf(sizeof(CDataType) * c_m.mDesc.GetElementSpace());

    a_m_device_buf.ToDevice(a_m.mData.data());
    b_m_device_buf.ToDevice(b_m.mData.data());

    auto broadcastAdd = DeviceElementwiseAddInstance{};
    auto argument     = broadcastAdd.MakeArgumentPointer(a_m_device_buf.GetDeviceBuffer(),
                                                         b_m_device_buf.GetDeviceBuffer(),
                                                         c_m_device_buf.GetDeviceBuffer(),
                                                         {M},
                                                         {1},
                                                         {1},
                                                         {1},
                                                         Add{},
                                                         256);

    if(!broadcastAdd.IsSupportedArgument(argument.get()))
    {
        throw std::runtime_error("The runtime parameters seems not supported by the "
                                 "DeviceBinaryElementwise_2D instance, exiting!");
    };

    auto broadcastAdd_invoker_ptr = broadcastAdd.MakeInvokerPointer();
    float ave_time =
        broadcastAdd_invoker_ptr->Run(argument.get(), StreamConfig{nullptr, time_kernel});

    std::cout << "Perf: " << ave_time << " ms" << std::endl;

    bool pass = true;
    if(do_verification)
    {
        c_m_device_buf.FromDevice(c_m.mData.data());
        Tensor<CDataType> host_c_m(f_host_tensor_descriptor1d(M, 1));

        host_elementwise1D<Tensor<ABDataType>,
                           Tensor<ABDataType>,
                           Tensor<CDataType>,
                           EltwiseComputeDataType,
                           Add,
                           0>(host_c_m, a_m, b_m, M, Add{});

        pass &= ck::utils::check_err(
            c_m.mData, host_c_m.mData, "Error: Incorrect results d1", 1e-3, 1e-3);
    }

    return pass ? 0 : 1;
}
example/19_cgemm/CMakeLists.txt → example/20_cgemm/CMakeLists.txt (file moved)

example/19_cgemm/cgemm_xdl_bf16.cpp → example/20_cgemm/cgemm_xdl_bf16.cpp (file moved)
example/CMakeLists.txt

@@ -51,4 +51,5 @@ add_subdirectory(15_grouped_gemm)
 add_subdirectory(16_gemm_reduce)
 add_subdirectory(17_convnd_bwd_data_xdl)
 add_subdirectory(18_batched_gemm_reduce)
-add_subdirectory(19_cgemm)
+add_subdirectory(19_binary_elementwise)
+add_subdirectory(20_cgemm)
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp (new file, mode 100644)

#pragma once
#include <iostream>
#include <vector>

#include "device.hpp"
#include "device_base.hpp"
#include "gridwise_binary_elementwise_1d.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ComputeDataType,
          typename ElementwiseFunctor,
          index_t Dim,
          index_t ScalarPerVector>
struct DeviceBinaryElementwise : public BaseOperator
{
    static constexpr auto I0 = Number<0>{};

    static auto MakeDescriptor_M0_1d(const std::vector<int>& shape,
                                     const std::vector<int>& stride,
                                     index_t gridSize,
                                     index_t threadPerBlock)
    {
        // 1d desc - [m]
        const auto desc_m0 =
            make_naive_tensor_descriptor(make_tuple(shape[0]), make_tuple(stride[0]));

        // pad
        const auto m0           = desc_m0.GetLength(I0);
        const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector;
        const auto pad          = math::integer_least_multiple(m0, loop_step) - m0;
        const auto desc_m0_pad =
            transform_tensor_descriptor(desc_m0,
                                        make_tuple(make_right_pad_transform(m0, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m0_pad;
    }

    static auto MakeDescriptor_M0_2d(const std::vector<int>& shape,
                                     const std::vector<int>& stride,
                                     index_t gridSize,
                                     index_t threadPerBlock)
    {
        const int m = shape[0];
        const int n = shape[1];

        // 2d desc - [m, n]
        const auto desc_m_n =
            make_naive_tensor_descriptor(make_tuple(m, n), make_tuple(stride[0], stride[1]));

        // 1d desc - [m * n]
        const auto desc_m0 =
            transform_tensor_descriptor(desc_m_n,
                                        make_tuple(make_merge_transform(make_tuple(m, n))),
                                        make_tuple(Sequence<0, 1>{}),
                                        make_tuple(Sequence<0>{}));

        // pad
        const auto m0           = desc_m0.GetLength(I0);
        const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector;
        const auto pad          = math::integer_least_multiple(m0, loop_step) - m0;
        const auto desc_m0_pad =
            transform_tensor_descriptor(desc_m0,
                                        make_tuple(make_right_pad_transform(m0, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m0_pad;
    }

    static auto MakeDescriptor_M0(const std::vector<int>& shape,
                                  const std::vector<int>& stride,
                                  index_t gridSize,
                                  index_t threadPerBlock)
    {
        static_assert(Dim == 1 || Dim == 2,
                      "wrong! DeviceBinaryElementwise not support this dimension");

        // TODO - 3D, 4D, 5D
        if constexpr(Dim == 1)
            return MakeDescriptor_M0_1d(shape, stride, gridSize, threadPerBlock);
        else if constexpr(Dim == 2)
            return MakeDescriptor_M0_2d(shape, stride, gridSize, threadPerBlock);
        else
            return make_naive_tensor_descriptor(make_tuple(0), make_tuple(0));
    }

    using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));

    using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
                                                            BDataType,
                                                            CDataType,
                                                            ComputeDataType,
                                                            GridDesc_M0,
                                                            ElementwiseFunctor,
                                                            ScalarPerVector>;

    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a,
                 const BDataType* p_b,
                 CDataType* p_c,
                 const std::vector<int>& shape,
                 const std::vector<int>& stride_a,
                 const std::vector<int>& stride_b,
                 const std::vector<int>& stride_c,
                 ElementwiseFunctor functor,
                 index_t threadPerBlock)
            : p_a_(p_a),
              p_b_(p_b),
              p_c_(p_c),
              functor_(functor),
              threadPerBlock_(threadPerBlock),
              gridSize_(128) // FIXME - Calculate the grid size by number of CU in the future
        {
            a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock_);
            b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock_);
            c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock_);
        }

        const ADataType* p_a_;
        const BDataType* p_b_;
        CDataType* p_c_;
        GridDesc_M0 a_grid_desc_m0_;
        GridDesc_M0 b_grid_desc_m0_;
        GridDesc_M0 c_grid_desc_m0_;
        ElementwiseFunctor functor_;
        index_t threadPerBlock_;
        index_t gridSize_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto kernel = kernel_elementwise_1d<GridwiseBinEltwise,
                                                      ADataType,
                                                      BDataType,
                                                      CDataType,
                                                      GridDesc_M0,
                                                      ElementwiseFunctor>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(arg.gridSize_),
                                                        dim3(arg.threadPerBlock_),
                                                        0,
                                                        arg.p_a_,
                                                        arg.p_b_,
                                                        arg.p_c_,
                                                        arg.a_grid_desc_m0_,
                                                        arg.b_grid_desc_m0_,
                                                        arg.c_grid_desc_m0_,
                                                        arg.functor_);
            return elapsed_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if(pArg == nullptr)
            return false;

        // m * n
        const auto m0 = pArg->c_grid_desc_m0_.GetLength(I0);

        if(m0 % ScalarPerVector != 0)
            return false;

        return true;
    };

    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      std::vector<int> shape,
                                                      std::vector<int> stride_a,
                                                      std::vector<int> stride_b,
                                                      std::vector<int> stride_c,
                                                      ElementwiseFunctor functor,
                                                      index_t threadPerBlock)
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          shape,
                                          stride_a,
                                          stride_b,
                                          stride_c,
                                          functor,
                                          threadPerBlock);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer()
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBinaryElementwise"
            << "<"
            << "ScalarPerVector = " << ScalarPerVector
            << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
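To make the descriptor padding concrete, here is the arithmetic of MakeDescriptor_M0_2d worked through with the constants used by broadcast_add_2d.cpp (gridSize_ = 128, threadPerBlock = 256, ScalarPerVector = 8, shape {1024, 1024}). The snippet is a standalone host-side illustration of that arithmetic, not library code.

// Sketch of the padding arithmetic in MakeDescriptor_M0_2d (illustration only).
#include <cstdio>

int main()
{
    const int m = 1024, n = 1024;
    const int gridSize = 128, threadPerBlock = 256, scalarPerVector = 8;

    const int m0        = m * n;                                        // merged 1-D length: 1048576
    const int loop_step = gridSize * threadPerBlock * scalarPerVector;  // 262144 elements per grid sweep
    const int padded    = ((m0 + loop_step - 1) / loop_step) * loop_step; // integer_least_multiple(m0, loop_step)
    const int pad       = padded - m0;                                  // 0 here: m0 is already a multiple
    const int num_iter  = padded / loop_step;                           // 4 grid-wide iterations

    std::printf("m0=%d loop_step=%d pad=%d num_iter=%d\n", m0, loop_step, pad, num_iter);
    return 0;
}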
include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp (new file, mode 100644)

#pragma once
#include "data_type.hpp"

namespace ck {
namespace tensor_operation {
namespace binary_element_wise {

struct Add
{
    __host__ __device__ constexpr void
    operator()(float& dst, const float& src1, const float& src2) const
    {
        dst = src1 + src2;
    }
};

} // namespace binary_element_wise
} // namespace tensor_operation
} // namespace ck
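Further binary operations would presumably follow the same shape: a functor whose __host__ __device__ operator() writes the result into its first argument, which is all the gridwise kernel requires of ElementwiseFunctor. A hypothetical Subtract, shown only to illustrate the pattern (it is not part of this commit):

// Hypothetical additional functor following the Add pattern above (illustration only).
struct Subtract
{
    __host__ __device__ constexpr void
    operator()(float& dst, const float& src1, const float& src2) const
    {
        dst = src1 - src2;
    }
};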
include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp (new file, mode 100644)

#pragma once

#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename GridwiseBinEltwise,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GridDesc_M0,
          typename ElementwiseFunctor>
__global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                      const BDataType* __restrict__ p_b_global,
                                      CDataType* __restrict__ p_c_global,
                                      const GridDesc_M0 a_grid_desc_m0,
                                      const GridDesc_M0 b_grid_desc_m0,
                                      const GridDesc_M0 c_grid_desc_m0,
                                      const ElementwiseFunctor functor)
{
    GridwiseBinEltwise::Run(p_a_global,
                            p_b_global,
                            p_c_global,
                            a_grid_desc_m0,
                            b_grid_desc_m0,
                            c_grid_desc_m0,
                            functor);
}

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ComputeDataType,
          typename GridDesc_M0,
          typename ElementwiseFunctor,
          index_t ScalarPerVector>
struct GridwiseBinaryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_desc_m0 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ __host__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
        return make_multi_index(global_thread_id * ScalarPerVector);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               CDataType* __restrict__ p_c_global,
                               const GridDesc_M0 a_grid_desc_m0,
                               const GridDesc_M0 b_grid_desc_m0,
                               const GridDesc_M0 c_grid_desc_m0,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m0.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m0.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m0.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf;

        const auto thread_to_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ComputeDataType,
                                             GridDesc_M0,
                                             decltype(thread_desc_m0),
                                             Sequence<ScalarPerVector>, // SliceLengths
                                             Sequence<0>,               // DimAccessOrder
                                             0,                         // SrcVectorDim
                                             ScalarPerVector,
                                             1,                         // SrcScalarStrideInVector
                                             false>{a_grid_desc_m0, thread_to_global_offset};

        auto b_global_load =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             ComputeDataType,
                                             GridDesc_M0,
                                             decltype(thread_desc_m0),
                                             Sequence<ScalarPerVector>, // SliceLengths
                                             Sequence<0>,               // DimAccessOrder
                                             0,                         // SrcVectorDim
                                             ScalarPerVector,
                                             1,                         // SrcScalarStrideInVector
                                             false>{b_grid_desc_m0, thread_to_global_offset};

        auto c_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               CDataType,
                                               decltype(thread_desc_m0),
                                               GridDesc_M0,
                                               PassThrough,
                                               Sequence<ScalarPerVector>, // SliceLengths
                                               Sequence<0>,               // DimAccessOrder
                                               0,                         // DstVectorDim
                                               ScalarPerVector,
                                               InMemoryDataOperationEnum::Set,
                                               1,                         // DstScalarStrideInVector
                                               false>{
                c_grid_desc_m0, thread_to_global_offset, PassThrough{}};

        const index_t threadPerBlock = get_block_size();
        const index_t blockPerGrid   = get_grid_size();
        const auto m0                = c_grid_desc_m0.GetLength(I0);
        const index_t loop_step      = blockPerGrid * threadPerBlock * ScalarPerVector;
        const auto loop_step_index   = make_multi_index(loop_step);

        index_t num_iter = m0 / (loop_step);
        do
        {
            // read and process ScalarPerVector elements
            a_global_load.Run(
                a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);

            b_global_load.Run(
                b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf);

            static_for<0, ScalarPerVector, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
                functor(c_thread_buf(Number<offset>{}),
                        a_thread_buf(Number<offset>{}),
                        b_thread_buf(Number<offset>{}));
            });

            c_global_write.Run(thread_desc_m0,
                               make_tuple(I0), // SrcSliceOriginIdx
                               c_thread_buf,
                               c_grid_desc_m0,
                               c_global_buf);

            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
            b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index);
            c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index);
        } while(--num_iter);
    }
};

} // namespace ck
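Structurally, Run is a vectorized grid-stride loop: every thread owns ScalarPerVector contiguous elements, and the whole grid advances by gridDim * blockDim * ScalarPerVector per iteration until the padded length is exhausted. The snippet below is a plain scalar HIP/CUDA analogue of the same access pattern with the descriptor and transfer machinery stripped away; it is a conceptual sketch only, not equivalent to the library kernel.

// Scalar grid-stride analogue of GridwiseBinaryElementwise_1D::Run (illustration only).
__global__ void add_kernel(const float* a, const float* b, float* c, int m0)
{
    const int loop_step = gridDim.x * blockDim.x;        // one element per thread per sweep
    for(int i = blockIdx.x * blockDim.x + threadIdx.x; i < m0; i += loop_step)
    {
        c[i] = a[i] + b[i];                              // the library version reads, computes,
    }                                                    // and writes ScalarPerVector elements at a time
}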
include/ck/utility/get_id.hpp

@@ -11,10 +11,14 @@ __host__ __device__ constexpr index_t get_warp_size()
 __device__ index_t get_thread_local_1d_id() { return threadIdx.x; }

 __device__ index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }

 __device__ index_t get_warp_local_1d_id() { return threadIdx.x / get_warp_size(); }

 __device__ index_t get_block_1d_id() { return blockIdx.x; }

+__device__ index_t get_grid_size() { return gridDim.x; }
+
+__device__ index_t get_block_size() { return blockDim.x; }
+
 } // namespace ck