"docs/git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "56e2835eb4dc13aadae86ee2de8ecb8415579424"
Commit cfce1f11 authored by rocking
Browse files

Support different YVectorDim in GridwiseLayernorm

parent 53607c7d
...@@ -122,6 +122,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -122,6 +122,7 @@ struct DeviceLayernorm : public BaseOperator
XSrcVectorSize, XSrcVectorSize,
GammaSrcVectorSize, GammaSrcVectorSize,
BetaSrcVectorSize, BetaSrcVectorSize,
XSrcVectorDim,
YDstVectorSize, YDstVectorSize,
false>; false>;
...@@ -142,6 +143,7 @@ struct DeviceLayernorm : public BaseOperator ...@@ -142,6 +143,7 @@ struct DeviceLayernorm : public BaseOperator
XSrcVectorSize, XSrcVectorSize,
GammaSrcVectorSize, GammaSrcVectorSize,
BetaSrcVectorSize, BetaSrcVectorSize,
XSrcVectorDim,
YDstVectorSize, YDstVectorSize,
true>; true>;
......
...@@ -65,13 +65,17 @@ template <typename XDataType, ...@@ -65,13 +65,17 @@ template <typename XDataType,
index_t XSrcVectorSize, index_t XSrcVectorSize,
index_t GammaSrcVectorSize, index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize, index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize, index_t YDstVectorSize,
bool SweepOnce> bool SweepOnce>
struct GridwiseLayernorm_mk_to_mk struct GridwiseLayernorm_mk_to_mk
{ {
static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) || static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)) && (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
(KThreadSliceSize % YDstVectorSize == 0), "Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"); "Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0); static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
...@@ -231,7 +235,7 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -231,7 +235,7 @@ struct GridwiseLayernorm_mk_to_mk
AccElementwiseOperation, AccElementwiseOperation,
ThreadBufferLengths_M_K, ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder, ThreadBufferDimAccessOrder,
XSrcVectorDim, YDstVectorDim,
YDstVectorSize, YDstVectorSize,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
1, 1,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment