Commit cfce1f11 authored by rocking
Browse files

Support different YVectorDim in GridwiseLayernorm

parent 53607c7d
......@@ -122,6 +122,7 @@ struct DeviceLayernorm : public BaseOperator
XSrcVectorSize,
GammaSrcVectorSize,
BetaSrcVectorSize,
XSrcVectorDim,
YDstVectorSize,
false>;
......@@ -142,6 +143,7 @@ struct DeviceLayernorm : public BaseOperator
XSrcVectorSize,
GammaSrcVectorSize,
BetaSrcVectorSize,
XSrcVectorDim,
YDstVectorSize,
true>;
......
......@@ -65,13 +65,17 @@ template <typename XDataType,
index_t XSrcVectorSize,
index_t GammaSrcVectorSize,
index_t BetaSrcVectorSize,
index_t YDstVectorDim,
index_t YDstVectorSize,
bool SweepOnce>
struct GridwiseLayernorm_mk_to_mk
{
static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)) &&
(KThreadSliceSize % YDstVectorSize == 0),
static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
(XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
(YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
"Invalid thread slice sizes and/or vector sizes configuration, please check!");
static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);
......@@ -231,7 +235,7 @@ struct GridwiseLayernorm_mk_to_mk
AccElementwiseOperation,
ThreadBufferLengths_M_K,
ThreadBufferDimAccessOrder,
XSrcVectorDim,
YDstVectorDim,
YDstVectorSize,
InMemoryDataOperationEnum::Set,
1,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment