Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
19c18624
Commit
19c18624
authored
Dec 16, 2022
by
fsx950223
Browse files
add multi embeddings support
parent
10c72ace
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
256 additions
and
200 deletions
+256
-200
example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp
..._sparse_embedding/sparse_embedding3_forward_layernorm.cpp
+180
-62
include/ck/tensor_operation/gpu/device/impl/device_sparse_embeddings_forward_layernorm.hpp
...evice/impl/device_sparse_embeddings_forward_layernorm.hpp
+31
-57
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp
...gpu/grid/gridwise_sparse_embeddings_forward_layernorm.hpp
+45
-81
No files found.
example/36_sparse_embedding/sparse_embedding3_forward_layernorm.cpp
View file @
19c18624
...
...
@@ -9,7 +9,7 @@
#include <ctime>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_sparse_embedding
3
_forward_layernorm.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_sparse_embedding
s
_forward_layernorm.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
...
...
@@ -18,13 +18,6 @@
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_sparse_embedding3_forward_layernorm.hpp"
// using EmbType = float;
// using IndexType = int64_t;
// using GammaDataType = float;
// using BetaDataType = float;
// using AccDataType = float;
// using OutType = float;
using
EmbType
=
ck
::
half_t
;
using
IndexType
=
int64_t
;
using
GammaDataType
=
ck
::
half_t
;
...
...
@@ -32,47 +25,172 @@ using BetaDataType = ck::half_t;
using
AccDataType
=
float
;
using
OutType
=
ck
::
half_t
;
// clang-format off
// BlockSize, DimClusterSize, RowClusterSize, DimPerBlock, RowPerBlock, DimThreadSize, RowVectorSize
using
DeviceInstance_fp32_e256
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
256
,
1
,
1
>
;
using
DeviceInstance_fp32_e512
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
512
,
1
,
1
>
;
using
DeviceInstance_fp32_e768
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
768
,
1
,
1
>
;
using
DeviceInstance_fp32_e1024
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
1024
,
1
,
1
>
;
using
DeviceInstance_fp32_e1536
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
1536
,
1
,
1
>
;
using
DeviceInstance_fp32_e2048
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
2048
,
1
,
4
>
;
using
DeviceInstance_fp32_e4096
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
4096
,
1
,
4
>
;
using
DeviceInstance_fp32_e8192
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
8192
,
1
,
4
>
;
using
DeviceInstance_fp32_e16384
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
16384
,
1
,
4
>
;
using
DeviceInstance_fp16_e256
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
256
,
1
,
1
>
;
using
DeviceInstance_fp16_e512
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
512
,
1
,
2
>
;
using
DeviceInstance_fp16_e768
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
768
,
1
,
1
>
;
using
DeviceInstance_fp16_e1024
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
1024
,
1
,
2
>
;
using
DeviceInstance_fp16_e1536
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
1536
,
1
,
2
>
;
using
DeviceInstance_fp16_e2048
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
2048
,
1
,
2
>
;
using
DeviceInstance_fp16_e4096
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
4096
,
1
,
8
>
;
using
DeviceInstance_fp16_e8192
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbedding3ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
8192
,
1
,
8
>
;
template
<
typename
emb_type
,
ck
::
index_t
dim
>
struct
emb_kernel
{};
template
<
>
struct
emb_kernel
<
float
,
256
>
{
using
kernel_type
=
DeviceInstance_fp32_e256
;
};
template
<
>
struct
emb_kernel
<
float
,
512
>
{
using
kernel_type
=
DeviceInstance_fp32_e512
;
};
template
<
>
struct
emb_kernel
<
float
,
768
>
{
using
kernel_type
=
DeviceInstance_fp32_e768
;
};
template
<
>
struct
emb_kernel
<
float
,
1024
>
{
using
kernel_type
=
DeviceInstance_fp32_e1024
;};
template
<
>
struct
emb_kernel
<
float
,
1536
>
{
using
kernel_type
=
DeviceInstance_fp32_e1536
;};
template
<
>
struct
emb_kernel
<
float
,
2048
>
{
using
kernel_type
=
DeviceInstance_fp32_e2048
;};
template
<
>
struct
emb_kernel
<
float
,
4096
>
{
using
kernel_type
=
DeviceInstance_fp32_e4096
;};
template
<
>
struct
emb_kernel
<
float
,
8192
>
{
using
kernel_type
=
DeviceInstance_fp32_e8192
;};
template
<
>
struct
emb_kernel
<
float
,
16384
>
{
using
kernel_type
=
DeviceInstance_fp32_e16384
;};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
256
>
{
using
kernel_type
=
DeviceInstance_fp16_e256
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
512
>
{
using
kernel_type
=
DeviceInstance_fp16_e512
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
768
>
{
using
kernel_type
=
DeviceInstance_fp16_e768
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
1024
>
{
using
kernel_type
=
DeviceInstance_fp16_e1024
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
1536
>
{
using
kernel_type
=
DeviceInstance_fp16_e1536
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
2048
>
{
using
kernel_type
=
DeviceInstance_fp16_e2048
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
4096
>
{
using
kernel_type
=
DeviceInstance_fp16_e4096
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
8192
>
{
using
kernel_type
=
DeviceInstance_fp16_e8192
;
};
using
DeviceInstance_fp16_e256
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
256
,
1
,
1
,
3
>
;
using
DeviceInstance_fp16_e512
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
512
,
1
,
2
,
3
>
;
using
DeviceInstance_fp16_e768
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
768
,
1
,
1
,
3
>
;
using
DeviceInstance_fp16_e1024
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
1024
,
1
,
2
,
3
>
;
using
DeviceInstance_fp16_e1536
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
1536
,
1
,
2
,
3
>
;
using
DeviceInstance_fp16_e2048
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
2048
,
1
,
2
,
3
>
;
using
DeviceInstance_fp16_e4096
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
4096
,
1
,
8
,
3
>
;
using
DeviceInstance_fp16_e8192
=
ck
::
tensor_operation
::
device
::
DeviceSparseEmbeddingsForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
256
,
1
,
256
,
1
,
8192
,
1
,
8
,
3
>
;
template
<
typename
emb_type
,
ck
::
index_t
dim
>
struct
emb_kernel
{
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
256
>
{
using
kernel_type
=
DeviceInstance_fp16_e256
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
512
>
{
using
kernel_type
=
DeviceInstance_fp16_e512
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
768
>
{
using
kernel_type
=
DeviceInstance_fp16_e768
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
1024
>
{
using
kernel_type
=
DeviceInstance_fp16_e1024
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
1536
>
{
using
kernel_type
=
DeviceInstance_fp16_e1536
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
2048
>
{
using
kernel_type
=
DeviceInstance_fp16_e2048
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
4096
>
{
using
kernel_type
=
DeviceInstance_fp16_e4096
;
};
template
<
>
struct
emb_kernel
<
ck
::
half_t
,
8192
>
{
using
kernel_type
=
DeviceInstance_fp16_e8192
;
};
// clang-format on
...
...
@@ -152,19 +270,19 @@ int main()
beta_dev
.
ToDevice
(
beta
.
mData
.
data
());
auto
device_instance
=
typename
emb_kernel
<
EmbType
,
current_dim
>::
kernel_type
{};
auto
argument_ptr
=
device_instance
.
MakeArgumentPointer
(
out_dev
.
GetDeviceBuffer
(),
emb_a
_dev
.
GetDeviceBuffer
(),
emb_
b
_dev
.
GetDeviceBuffer
(),
emb_
c
_dev
.
GetDeviceBuffer
(),
index_a
_dev
.
GetDeviceBuffer
(),
index_
b
_dev
.
GetDeviceBuffer
(),
index_
c
_dev
.
GetDeviceBuffer
(),
gamma
_dev
.
GetDeviceBuffer
(),
bet
a_dev
.
GetDeviceBuffer
(),
num_rows
,
current_dim
,
index_length
,
epsilon
);
auto
argument_ptr
=
device_instance
.
MakeArgumentPointer
(
out
_dev
.
GetDeviceBuffer
(),
{
ck
::
type_convert
<
EmbType
*>
(
emb_
a
_dev
.
GetDeviceBuffer
()
)
,
ck
::
type_convert
<
EmbType
*>
(
emb_
b
_dev
.
GetDeviceBuffer
()
)
,
ck
::
type_convert
<
EmbType
*>
(
emb_c
_dev
.
GetDeviceBuffer
()
)}
,
{
ck
::
type_convert
<
IndexType
*>
(
index_
a
_dev
.
GetDeviceBuffer
()
)
,
ck
::
type_convert
<
IndexType
*>
(
index_
b
_dev
.
GetDeviceBuffer
()
)
,
ck
::
type_convert
<
IndexType
*>
(
index_c
_dev
.
GetDeviceBuffer
()
)}
,
gamm
a_dev
.
GetDeviceBuffer
(),
beta_dev
.
GetDeviceBuffer
()
,
current_dim
,
index_length
,
epsilon
);
std
::
cout
<<
"Dim:"
<<
current_dim
<<
", kernel:"
<<
device_instance
.
GetTypeString
()
<<
std
::
endl
<<
std
::
flush
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_sparse_embedding
3
_forward_layernorm.hpp
→
include/ck/tensor_operation/gpu/device/impl/device_sparse_embedding
s
_forward_layernorm.hpp
View file @
19c18624
...
...
@@ -12,7 +12,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding
3
_forward_layernorm.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_sparse_embedding
s
_forward_layernorm.hpp"
namespace
ck
{
namespace
tensor_operation
{
...
...
@@ -30,10 +30,10 @@ template <typename EmbType,
ck
::
index_t
DimPerBlock
,
ck
::
index_t
RowPerBlock
,
ck
::
index_t
DimThreadSize
,
ck
::
index_t
RowVectorSize
>
struct
DeviceSparseEmbedding3ForwardLayernorm
:
public
BaseOperator
ck
::
index_t
RowVectorSize
,
ck
::
index_t
NumEmbeddings
>
struct
DeviceSparseEmbeddingsForwardLayernorm
:
public
BaseOperator
{
static
auto
MakeOutputDescriptor
(
const
index_t
index_length
,
const
index_t
rows
)
{
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
index_length
,
rows
));
...
...
@@ -42,28 +42,18 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
struct
Argument
:
public
BaseArgument
{
Argument
(
OutType
*
p_out
,
const
EmbType
*
p_emb_a
,
const
EmbType
*
p_emb_b
,
const
EmbType
*
p_emb_c
,
const
IndexType
*
p_index_a
,
const
IndexType
*
p_index_b
,
const
IndexType
*
p_index_c
,
const
ck
::
Array
<
EmbType
*
,
NumEmbeddings
>&
p_embs
,
const
ck
::
Array
<
IndexType
*
,
NumEmbeddings
>&
p_indexs
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
const
ck
::
index_t
NumRows
,
const
ck
::
index_t
EmbeddingDim
,
const
ck
::
index_t
IndexLength
,
const
AccDataType
epsilon
)
:
p_out_
(
p_out
),
p_emb_a_
(
p_emb_a
),
p_emb_b_
(
p_emb_b
),
p_emb_c_
(
p_emb_c
),
p_index_a_
(
p_index_a
),
p_index_b_
(
p_index_b
),
p_index_c_
(
p_index_c
),
p_embs_
(
p_embs
),
p_indexs_
(
p_indexs
),
p_gamma_
(
p_gamma
),
p_beta_
(
p_beta
),
NumRows_
(
NumRows
),
EmbeddingDim_
(
EmbeddingDim
),
IndexLength_
(
IndexLength
),
epsilon_
(
epsilon
)
...
...
@@ -72,15 +62,10 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
}
OutType
*
p_out_
;
const
EmbType
*
p_emb_a_
;
const
EmbType
*
p_emb_b_
;
const
EmbType
*
p_emb_c_
;
const
IndexType
*
p_index_a_
;
const
IndexType
*
p_index_b_
;
const
IndexType
*
p_index_c_
;
ck
::
Array
<
EmbType
*
,
NumEmbeddings
>
p_embs_
;
ck
::
Array
<
IndexType
*
,
NumEmbeddings
>
p_indexs_
;
const
GammaDataType
*
p_gamma_
;
const
BetaDataType
*
p_beta_
;
ck
::
index_t
NumRows_
;
ck
::
index_t
EmbeddingDim_
;
ck
::
index_t
IndexLength_
;
AccDataType
epsilon_
;
...
...
@@ -88,37 +73,28 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
size_t
grid_size_
;
};
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
void
*
p_out
,
const
void
*
p_emb_a
,
const
void
*
p_emb_b
,
const
void
*
p_emb_c
,
const
void
*
p_index_a
,
const
void
*
p_index_b
,
const
void
*
p_index_c
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
ck
::
index_t
NumRows
,
ck
::
index_t
EmbeddingDim
,
ck
::
index_t
IndexLength
,
const
AccDataType
epsilon
)
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
void
*
p_out
,
const
ck
::
Array
<
EmbType
*
,
NumEmbeddings
>&
p_embs
,
const
ck
::
Array
<
IndexType
*
,
NumEmbeddings
>&
p_indexs
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
ck
::
index_t
EmbeddingDim
,
ck
::
index_t
IndexLength
,
const
AccDataType
epsilon
)
{
return
std
::
make_unique
<
Argument
>
(
reinterpret_cast
<
OutType
*>
(
p_out
),
reinterpret_cast
<
const
EmbType
*>
(
p_emb_a
),
reinterpret_cast
<
const
EmbType
*>
(
p_emb_b
),
reinterpret_cast
<
const
EmbType
*>
(
p_emb_c
),
reinterpret_cast
<
const
IndexType
*>
(
p_index_a
),
reinterpret_cast
<
const
IndexType
*>
(
p_index_b
),
reinterpret_cast
<
const
IndexType
*>
(
p_index_c
),
p_embs
,
p_indexs
,
reinterpret_cast
<
const
GammaDataType
*>
(
p_gamma
),
reinterpret_cast
<
const
BetaDataType
*>
(
p_beta
),
NumRows
,
EmbeddingDim
,
IndexLength
,
epsilon
);
}
using
GridwiseSparseEmbedding
=
GridwiseSparseEmbedding
3
ForwardLayernorm
<
EmbType
,
GridwiseSparseEmbedding
s
ForwardLayernorm
<
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
...
...
@@ -131,7 +107,8 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
DimPerBlock
,
RowPerBlock
,
DimThreadSize
,
RowVectorSize
>
;
RowVectorSize
,
NumEmbeddings
>
;
struct
Invoker
:
public
BaseInvoker
{
...
...
@@ -139,14 +116,15 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
{
auto
out_desc
=
MakeOutputDescriptor
(
arg
.
IndexLength_
,
arg
.
EmbeddingDim_
);
const
auto
kernel_main
=
kernel_sparse_embedding
3
_forward_layernorm
<
GridwiseSparseEmbedding
,
kernel_sparse_embedding
s
_forward_layernorm
<
GridwiseSparseEmbedding
,
EmbType
,
IndexType
,
GammaDataType
,
BetaDataType
,
AccDataType
,
OutType
,
decltype
(
out_desc
)
>
;
decltype
(
out_desc
),
NumEmbeddings
>
;
float
avg_time
=
0
;
avg_time
+=
launch_and_time_kernel
(
stream_config
,
kernel_main
,
...
...
@@ -154,12 +132,8 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
dim3
(
BlockSize
),
0
,
arg
.
p_out_
,
arg
.
p_emb_a_
,
arg
.
p_emb_b_
,
arg
.
p_emb_c_
,
arg
.
p_index_a_
,
arg
.
p_index_b_
,
arg
.
p_index_c_
,
arg
.
p_embs_
,
arg
.
p_indexs_
,
arg
.
p_gamma_
,
arg
.
p_beta_
,
out_desc
,
...
...
@@ -177,7 +151,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
static
bool
IsSupportedArgument
(
const
Argument
*
p_arg
)
{
return
(
RowPerBlock
==
p_arg
->
EmbeddingDim_
)
&&
(
p_arg
->
NumRows_
%
DimPerBlock
==
0
)
;
return
(
RowPerBlock
==
p_arg
->
EmbeddingDim_
);
}
bool
IsSupportedArgument
(
const
BaseArgument
*
p_arg
)
override
...
...
@@ -195,7 +169,7 @@ struct DeviceSparseEmbedding3ForwardLayernorm : public BaseOperator
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"DeviceSparseEmbedding
3
ForwardLayernorm_"
<<
BlockSize
<<
"_"
<<
str
<<
"DeviceSparseEmbedding
s
ForwardLayernorm_"
<<
BlockSize
<<
"_"
<<
DimClusterSize
<<
"x"
<<
RowClusterSize
<<
"_"
<<
DimPerBlock
<<
"x"
<<
RowPerBlock
<<
"_"
<<
DimThreadSize
<<
"x"
<<
RowVectorSize
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding
3
_forward_layernorm.hpp
→
include/ck/tensor_operation/gpu/grid/gridwise_sparse_embedding
s
_forward_layernorm.hpp
View file @
19c18624
...
...
@@ -17,33 +17,21 @@ template <typename GridwiseSparseEmbedding,
typename
BetaDataType
,
typename
AccDataType
,
typename
OutType
,
typename
OutGridDesc
>
typename
OutGridDesc
,
ck
::
index_t
NumEmbeddings
>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
__global__
void
kernel_sparse_embedding3_forward_layernorm
(
OutType
*
p_out
,
const
EmbType
*
p_emb_a
,
const
EmbType
*
p_emb_b
,
const
EmbType
*
p_emb_c
,
const
IndexType
*
p_index_a
,
const
IndexType
*
p_index_b
,
const
IndexType
*
p_index_c
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
const
OutGridDesc
out_grid_desc
,
const
AccDataType
epsilon
)
__global__
void
kernel_sparse_embeddings_forward_layernorm
(
OutType
*
p_out
,
const
ck
::
Array
<
EmbType
*
,
NumEmbeddings
>
p_embs
,
const
ck
::
Array
<
IndexType
*
,
NumEmbeddings
>
p_indexes
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
const
OutGridDesc
out_grid_desc
,
const
AccDataType
epsilon
)
{
GridwiseSparseEmbedding
::
Run
(
p_out
,
p_emb_a
,
p_emb_b
,
p_emb_c
,
p_index_a
,
p_index_b
,
p_index_c
,
p_gamma
,
p_beta
,
out_grid_desc
,
epsilon
);
GridwiseSparseEmbedding
::
Run
(
p_out
,
p_embs
,
p_indexes
,
p_gamma
,
p_beta
,
out_grid_desc
,
epsilon
);
}
template
<
typename
EmbType
,
...
...
@@ -59,8 +47,9 @@ template <typename EmbType,
ck
::
index_t
DimPerBlock
,
// Row x Dim, along Dim
ck
::
index_t
RowPerBlock
,
// Row x Dim, along Row
ck
::
index_t
DimThreadSize
,
// this is actually not vector, but number of registers
ck
::
index_t
RowVectorSize
>
struct
GridwiseSparseEmbedding3ForwardLayernorm
ck
::
index_t
RowVectorSize
,
ck
::
index_t
NumEmbeddings
>
struct
GridwiseSparseEmbeddingsForwardLayernorm
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -85,8 +74,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
using
ThreadwiseWolfordDesc2D
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
DimSubBlocks
*
DimThreadSize
>
{},
Number
<
RowSubBlocks
*
RowVectorSize
>
{})));
using
ThreadwiseWolfordDescReduce
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
DimSubBlocks
*
DimThreadSize
>
{})));
using
ThreadwiseWolfordDescReduce
=
decltype
(
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
DimSubBlocks
*
DimThreadSize
>
{})));
using
ThreadwiseWelford
=
ThreadwiseWelford
<
AccDataType
,
ThreadwiseWolfordDesc2D
,
ThreadwiseWolfordDescReduce
>
;
...
...
@@ -97,12 +86,8 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
BlockwiseWelford
<
AccDataType
,
BlockSize
,
ThreadClusterLength
,
Sequence
<
0
,
1
>>
;
__device__
static
void
Run
(
OutType
*
p_out
,
const
EmbType
*
p_emb_a
,
const
EmbType
*
p_emb_b
,
const
EmbType
*
p_emb_c
,
const
IndexType
*
p_index_a
,
const
IndexType
*
p_index_b
,
const
IndexType
*
p_index_c
,
const
ck
::
Array
<
EmbType
*
,
NumEmbeddings
>
p_embs
,
const
ck
::
Array
<
IndexType
*
,
NumEmbeddings
>
p_indexes
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
const
OutGridDesc
,
...
...
@@ -111,9 +96,6 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
const
index_t
thread_local_id
=
get_thread_local_1d_id
();
const
index_t
block_global_id
=
get_block_1d_id
();
// const auto index_length = out_grid_desc.GetLength(I0);
// const auto emb_dim = out_grid_desc.GetLength(I1);
constexpr
auto
thread_cluster_desc
=
make_cluster_descriptor
(
Sequence
<
DimClusterSize
,
RowClusterSize
>
{},
Sequence
<
0
,
1
>
{});
...
...
@@ -141,13 +123,11 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
constexpr
auto
gamma_beta_buf_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
RowSubBlocks
,
RowVectorSize
));
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
EmbType
,
thread_buf_size
,
true
>
in_thread_buf_a
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
EmbType
,
thread_buf_size
,
true
>
in_thread_buf_b
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
EmbType
,
thread_buf_size
,
true
>
in_thread_buf_c
;
StaticBuffer
<
AddressSpaceEnum
::
Sgpr
,
IndexType
,
DimPerBlock
,
true
>
index_buf_a
;
StaticBuffer
<
AddressSpaceEnum
::
Sgpr
,
IndexType
,
DimPerBlock
,
true
>
index_buf_b
;
StaticBuffer
<
AddressSpaceEnum
::
Sgpr
,
IndexType
,
DimPerBlock
,
true
>
index_buf_c
;
ck
::
Array
<
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
EmbType
,
thread_buf_size
,
true
>
,
NumEmbeddings
>
in_thread_bufs
;
ck
::
Array
<
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
IndexType
,
DimPerBlock
,
true
>
,
NumEmbeddings
>
index_bufs
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
thread_buf_size
,
true
>
acc_thread_buf
;
...
...
@@ -160,42 +140,30 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
mean_var_buf_size
,
true
>
var_thread_buf
;
auto
load_current_sub_row
=
[
&
](
auto
i_dim_sub_
,
auto
i_row_sub_
)
{
vector_type_maker_t
<
EmbType
,
RowVectorSize
>
emb_vector_a
;
vector_type_maker_t
<
EmbType
,
RowVectorSize
>
emb_vector_b
;
vector_type_maker_t
<
EmbType
,
RowVectorSize
>
emb_vector_c
;
using
src_vector_t
=
typename
decltype
(
emb_vector_a
)
::
type
;
ck
::
Array
<
vector_type_maker_t
<
EmbType
,
RowVectorSize
>
,
NumEmbeddings
>
emb_vectors
;
auto
emb_a
=
emb_vectors
[
0
];
using
src_vector_t
=
typename
decltype
(
emb_a
)
::
type
;
static_for
<
0
,
DimThreadSize
,
1
>
{}([
&
](
auto
i_dim_vec_
)
{
constexpr
auto
current_dim
=
i_dim_sub_
*
DimPerSubBlock
+
i_dim_vec_
;
IndexType
index_a
=
index_buf_a
[
Number
<
current_dim
>
{}];
IndexType
index_b
=
index_buf_b
[
Number
<
current_dim
>
{}];
IndexType
index_c
=
index_buf_c
[
Number
<
current_dim
>
{}];
auto
thread_offset
=
(
thread_row_cluster_id
+
i_row_sub_
*
RowClusterSize
)
*
sizeof
(
EmbType
)
*
RowVectorSize
;
static_for
<
0
,
NumEmbeddings
,
1
>
{}([
&
](
auto
i_embedding_
)
{
IndexType
index
=
index_bufs
[
i_embedding_
.
value
][
Number
<
current_dim
>
{}];
int32x4_t
emb_res_a
=
make_wave_buffer_resource_with_default_range
(
p_emb_a
+
index_a
*
RowPerBlock
);
int32x4_t
emb_res_b
=
make_wave_buffer_resource_with_default_range
(
p_emb_b
+
index_b
*
RowPerBlock
);
int32x4_t
emb_res_c
=
make_wave_buffer_resource_with_default_range
(
p_emb_c
+
index_c
*
RowPerBlock
);
emb_vector_a
.
template
AsType
<
src_vector_t
>()(
I0
)
=
amd_buffer_load_impl
<
EmbType
,
RowVectorSize
>
(
emb_res_a
,
thread_offset
,
0
);
emb_vector_b
.
template
AsType
<
src_vector_t
>()(
I0
)
=
amd_buffer_load_impl
<
EmbType
,
RowVectorSize
>
(
emb_res_b
,
thread_offset
,
0
);
emb_vector_c
.
template
AsType
<
src_vector_t
>()(
I0
)
=
amd_buffer_load_impl
<
EmbType
,
RowVectorSize
>
(
emb_res_c
,
thread_offset
,
0
);
int32x4_t
emb_res
=
make_wave_buffer_resource_with_default_range
(
p_embs
[
i_embedding_
.
value
]
+
index
*
RowPerBlock
);
emb_vectors
(
i_embedding_
.
value
).
template
AsType
<
src_vector_t
>()(
I0
)
=
amd_buffer_load_impl
<
EmbType
,
RowVectorSize
>
(
emb_res
,
thread_offset
,
0
);
});
static_for
<
0
,
RowVectorSize
,
1
>
{}([
&
](
auto
i_row_vec_
)
{
constexpr
auto
register_offset
=
thread_buf_desc
.
CalculateOffset
(
make_tuple
(
i_dim_sub_
,
i_dim_vec_
,
i_row_sub_
,
i_row_vec_
));
in_thread_buf_a
(
Number
<
register_offset
>
{})
=
emb_vector_a
.
template
AsType
<
EmbType
>()[
i_row_vec_
];
in_thread_buf_b
(
Number
<
register_offset
>
{})
=
emb_vector_b
.
template
AsType
<
EmbType
>()[
i_row_vec_
];
in_thread_buf_c
(
Number
<
register_offset
>
{})
=
emb_vector_c
.
template
AsType
<
EmbType
>()[
i_row_vec_
];
static_for
<
0
,
NumEmbeddings
,
1
>
{}([
&
](
auto
i_embedding_
)
{
in_thread_bufs
(
i_embedding_
.
value
)(
Number
<
register_offset
>
{})
=
emb_vectors
[
i_embedding_
.
value
].
template
AsType
<
EmbType
>()[
i_row_vec_
];
});
});
});
};
...
...
@@ -205,14 +173,10 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
static_for
<
0
,
RowVectorSize
,
1
>
{}([
&
](
auto
i_row_vec_
)
{
constexpr
auto
register_offset
=
thread_buf_desc
.
CalculateOffset
(
make_tuple
(
i_dim_sub_
,
i_dim_vec_
,
i_row_sub_
,
i_row_vec_
));
AccDataType
va
=
ck
::
type_convert
<
AccDataType
>
(
in_thread_buf_a
(
Number
<
register_offset
>
{}));
AccDataType
vb
=
ck
::
type_convert
<
AccDataType
>
(
in_thread_buf_b
(
Number
<
register_offset
>
{}));
AccDataType
vc
=
ck
::
type_convert
<
AccDataType
>
(
in_thread_buf_c
(
Number
<
register_offset
>
{}));
acc_thread_buf
(
Number
<
register_offset
>
{})
+=
va
+
vb
+
vc
;
static_for
<
0
,
NumEmbeddings
,
1
>
{}([
&
](
auto
i_embedding_
)
{
acc_thread_buf
(
Number
<
register_offset
>
{})
+=
ck
::
type_convert
<
AccDataType
>
(
in_thread_bufs
(
i_embedding_
.
value
)(
Number
<
register_offset
>
{}));
});
});
});
};
...
...
@@ -273,9 +237,10 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
// first load index
ck
::
static_for
<
0
,
DimPerBlock
,
1
>
{}([
&
](
auto
i_idx_
)
{
// prefer use s_load
index_buf_a
(
i_idx_
)
=
p_index_a
[
index_start
+
i_idx_
.
value
];
index_buf_b
(
i_idx_
)
=
p_index_b
[
index_start
+
i_idx_
.
value
];
index_buf_c
(
i_idx_
)
=
p_index_c
[
index_start
+
i_idx_
.
value
];
ck
::
static_for
<
0
,
NumEmbeddings
,
1
>
{}([
&
](
auto
i_embedding_
)
{
index_bufs
(
i_embedding_
.
value
)(
i_idx_
)
=
p_indexes
[
i_embedding_
.
value
][
index_start
+
i_idx_
.
value
];
});
});
// load gamma/beta
...
...
@@ -329,7 +294,6 @@ struct GridwiseSparseEmbedding3ForwardLayernorm
static_for
<
0
,
mean_var_buf_size
,
1
>
{}([
&
](
auto
I
)
{
if
constexpr
(
I
>
0
)
block_sync_lds
();
BlockwiseWelford
::
Run
(
mean_thread_buf
(
I
),
var_thread_buf
(
I
),
threadwise_welford
.
cur_count_
);
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment