Commit 6582aedc authored by binmakeswell

fix format (#583)

parent f08fc17f

#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

#include "block_reduce.h"

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T pack[pack_size];
  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(src_row + idx, pack);
    BlockStore(ts_store).Store(dst_row + idx, pack);
  }
}

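// All dispatch/combine helpers in this file share one structure: a thread
// block cooperatively walks a single row of `cols` elements, moving
// pack_size elements per thread per iteration through CUB's vectorized
// BlockLoad/BlockStore, i.e. bpack_size = block_size * pack_size elements
// per loop trip.
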
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T pack[pack_size];
  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(dst_row + idx, pack);
    BlockStore(ts_store).Store(src_row + idx, pack);
  }
}

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
                                 const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T pack[pack_size];
  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(src_row + idx, pack);
    BlockStore(ts_store).Store(dst_row1 + idx, pack);
    BlockStore(ts_store).Store(dst_row2 + idx, pack);
  }
}

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
                                 const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T pack1[pack_size], pack2[pack_size];
  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(dst_row1 + idx, pack1);
    BlockLoad(ts_load).Load(dst_row2 + idx, pack2);
#pragma unroll
    for (int i = 0; i < pack_size; ++i) {
      pack1[i] += pack2[i];
    }
    BlockStore(ts_store).Store(src_row + idx, pack1);
  }
}

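// Backward of a two-way dispatch: the gradient of the original token row is
// the sum of the gradients arriving from both expert copies, hence the
// element-wise pack1[i] += pack2[i] before the single store to src_row.
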
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
                               const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T pack[pack_size];
  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(src_row + idx, pack);
#pragma unroll
    for (int i = 0; i < pack_size; ++i) {
      pack[i] *= weight;
    }
    BlockStore(ts_store).Store(dst_row + idx, pack);
  }
}

template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
                               T *weight_grad, const T weight, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T grad[pack_size], tokens[pack_size];
  float thread_sum = 0;
  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(dst_row + idx, grad);
    BlockLoad(ts_load).Load(tks_row + idx, tokens);
#pragma unroll
    for (int i = 0; i < pack_size; ++i) {
      thread_sum += grad[i] * tokens[i];
      grad[i] *= weight;
    }
    BlockStore(ts_store).Store(src_row + idx, grad);
  }

  blockReduce<ReduceType::kSum, 1>(&thread_sum);

  if (threadIdx.x == 0)
    *weight_grad = static_cast<T>(thread_sum);
}

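// moe_cb_one_bwd produces two gradients in one pass: the expert-output
// gradient (the incoming gradient scaled by the routing weight, stored to
// src_row) and the scalar routing-weight gradient, which is the dot product
// of the incoming gradient row with the saved expert output (tks_row). The
// per-thread partial dot products are folded with blockReduce<kSum> before
// thread 0 publishes the result.
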
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
                               const T weight1, const T weight2,
                               const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T pack1[pack_size], pack2[pack_size];
  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(src_row1 + idx, pack1);
    BlockLoad(ts_load).Load(src_row2 + idx, pack2);
#pragma unroll
    for (int i = 0; i < pack_size; ++i) {
      pack1[i] = pack1[i] * weight1 + pack2[i] * weight2;
    }
    BlockStore(ts_store).Store(dst_row + idx, pack1);
  }
}

template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
                               T *tks_row1, T *tks_row2, T *weight_grad1,
                               T *weight_grad2, const T weight1,
                               const T weight2, const int cols) {
  assert(cols % pack_size == 0);
  const int bpack_size = block_size * pack_size;

  typedef cub::BlockLoad<T, block_size, pack_size, cub::BLOCK_LOAD_VECTORIZE>
      BlockLoad;
  __shared__ typename BlockLoad::TempStorage ts_load;
  typedef cub::BlockStore<T, block_size, pack_size, cub::BLOCK_STORE_VECTORIZE>
      BlockStore;
  __shared__ typename BlockStore::TempStorage ts_store;

  int tps = threadIdx.x * pack_size;
  T grad[pack_size], tokens1[pack_size], tokens2[pack_size], sgrad1[pack_size],
      sgrad2[pack_size];
  float thread_sum[2] = {0, 0};
  for (int idx = 0; idx + tps < cols; idx += bpack_size) {
    BlockLoad(ts_load).Load(dst_row + idx, grad);
    BlockLoad(ts_load).Load(tks_row1 + idx, tokens1);
    BlockLoad(ts_load).Load(tks_row2 + idx, tokens2);
#pragma unroll
    for (int i = 0; i < pack_size; ++i) {
      thread_sum[0] += grad[i] * tokens1[i];
      thread_sum[1] += grad[i] * tokens2[i];
      sgrad1[i] = weight1 * grad[i];
      sgrad2[i] = weight2 * grad[i];
    }
    BlockStore(ts_store).Store(src_row1 + idx, sgrad1);
    BlockStore(ts_store).Store(src_row2 + idx, sgrad2);
  }

  blockReduce<ReduceType::kSum, 2>(thread_sum);

  if (threadIdx.x == 0)
    *weight_grad1 = static_cast<T>(thread_sum[0]);
  else if (threadIdx.x == 1)
    *weight_grad2 = static_cast<T>(thread_sum[1]);
}

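// The two-way variant reduces both dot products in one blockReduce<kSum, 2>
// call; the reduced values are presumably left visible to the low-numbered
// threads afterwards, so thread 0 and thread 1 each publish one
// routing-weight gradient.
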
// DISPATCH KERNELS --------------------------------
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
                                      const int cols, const int indicator1,
                                      const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                               cols);
  else if (indicator1 != 0)
    moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row1, cols);
  else if (indicator2 != 0)
    moe_dpch_one_fwd<T, block_size, pack_size>(src_row, dst_row2, cols);
  else
    return;
}

template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
                                      const int cols, const int indicator1,
                                      const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                               cols);
  else if (indicator1 != 0)
    moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row1, cols);
  else if (indicator2 != 0)
    moe_dpch_one_bwd<T, block_size, pack_size>(src_row, dst_row2, cols);
  else
    return;
}

template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
                                    int *mask1, int *mask2, int *dest1,
                                    int *dest2, const int h) {
  int row = blockIdx.x;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  moe_dpch_fwd_selector<T, block_size, pack_size>(
      batch_tokens + (row * h), expert_input + (dest1[row] * h),
      expert_input + (dest2[row] * h), h, mask1[row], indicator2);
}

template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
                                    int *mask2, int *dest1, int *dest2,
                                    const int h) {
  int row = blockIdx.x;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  moe_dpch_bwd_selector<T, block_size, pack_size>(
      tokens_grad + (row * h), expert_grad + (dest1[row] * h),
      expert_grad + (dest2[row] * h), h, mask1[row], indicator2);
}

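// Both dispatch kernels are launched with gridDim.x == s, one block per
// token row (see the launch functions below). mask2/dest2 describe an
// optional second routing choice; a null mask2 makes indicator2 zero, so the
// selectors fall back to the one-destination paths.
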
// COMBINE KERNELS --------------------------------
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
                                    const int cols, const T weight1,
                                    const T weight2, const int indicator1,
                                    const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
                                             weight1, weight2, cols);
  else if (indicator1 != 0)
    moe_cb_one_fwd<T, block_size, pack_size>(src_row1, dst_row, weight1, cols);
  else if (indicator2 != 0)
    moe_cb_one_fwd<T, block_size, pack_size>(src_row2, dst_row, weight2, cols);
  else
    return;
}

template <typename T, int block_size, int pack_size>
__device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
                                    const int cols, T *tks_row1, T *tks_row2,
                                    T *wt_grad1, T *wt_grad2, const T weight1,
                                    const T weight2, const int indicator1,
                                    const int indicator2) {
  if (indicator1 != 0 && indicator2 != 0)
    moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
                                             tks_row1, tks_row2, wt_grad1,
                                             wt_grad2, weight1, weight2, cols);
  else if (indicator1 != 0)
    moe_cb_one_bwd<T, block_size, pack_size>(src_row1, dst_row, tks_row1,
                                             wt_grad1, weight1, cols);
  else if (indicator2 != 0)
    moe_cb_one_bwd<T, block_size, pack_size>(src_row2, dst_row, tks_row2,
                                             wt_grad2, weight2, cols);
  else
    return;
}

template <typename T, int block_size, int pack_size>
__global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
                                  T *logits, int *mask1, int *mask2, int *dest1,
                                  int *dest2, const int e, const int c,
                                  const int h) {
  int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  T *row_log = logits + (row * e);
  moe_cb_fwd_selector<T, block_size, pack_size>(
      expert_tokens + (dest1[row] * h), expert_tokens + (dest2[row] * h),
      combine_tokens + (row * h), h, row_log[eid1], row_log[eid2], mask1[row],
      indicator2);
}

template <typename T, int block_size, int pack_size>
__global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
                                  T *logits, T *logits_grad, int *mask1,
                                  int *mask2, int *dest1, int *dest2,
                                  const int e, const int c, const int h) {
  int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
  int indicator2 = mask2 == nullptr ? 0 : mask2[row];
  T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
  moe_cb_bwd_selector<T, block_size, pack_size>(
      expert_grad + (dest1[row] * h), expert_grad + (dest2[row] * h),
      tokens_grad + (row * h), h, tks + (dest1[row] * h),
      tks + (dest2[row] * h), row_grad + eid1, row_grad + eid2, row_log[eid1],
      row_log[eid2], mask1[row], indicator2);
}

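// dest1/dest2 index into flattened [e * c, h] expert buffers, so
// eid = dest[row] / c recovers the expert id, row_log[eid] the routing logit
// used as the combine weight, and row_grad + eid the slot that receives that
// weight's gradient.
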
// CUMSUM KERNEL --------------------------------
template <int block_size, int pack_size>
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
                              const int e) {
  assert(s % pack_size == 0);
  constexpr int bpack_size = block_size * pack_size;
  int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
  __shared__ int temp[block_size + 1];
  int pack[pack_size];

  for (int idx = 0; idx < s; idx += bpack_size) {
    int offset = 1;

    if (idx + tps < s) {
      temp[tid] = inputs[tps * e + bid];
#pragma unroll
      for (int i = 1; i < pack_size; ++i) {
        pack[i] = inputs[(tps + i) * e + bid];
      }
#pragma unroll
      for (int i = 1; i < pack_size; ++i) {
        temp[tid] += pack[i];
      }
    }

    for (int i = block_size >> 1; i > 0; i >>= 1) {
      __syncthreads();
      if (tid < i) {
        int j = offset * (2 * tid + 1) - 1;
        temp[j + offset] += temp[j];
      }
      offset <<= 1;
    }

    if (tid == 0) {
      temp[block_size] = temp[block_size - 1];
      temp[block_size - 1] = 0;
    }

    for (int i = 1; i < block_size; i <<= 1) {
      offset >>= 1;
      __syncthreads();
      if (tid < i) {
        int j = offset * (2 * tid + 1) - 1, k = j + offset, ts = temp[j];
        temp[j] = temp[k];
        temp[k] += ts;
      }
    }

    __syncthreads();
    if (tid == 0)
      temp[0] = temp[block_size];
    __syncthreads();

    if (idx + tps < s) {
      temp[tid + 1] += last_sum;
#pragma unroll
      for (int i = pack_size - 1; i > 0; --i) {
        outputs[(tps + i) * e + bid] = temp[tid + 1];
        temp[tid + 1] -= pack[i];
      }
      outputs[tps * e + bid] = temp[tid + 1];
    }

    __syncthreads();
    last_sum += temp[0];
    inputs += bpack_size * e;
    outputs += bpack_size * e;
  }
}

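// cumsum_kernel scans one expert column per block (bid indexes the column,
// stride e) with a classic Blelloch scan in shared memory: the up-sweep
// builds a partial-sum tree, tid 0 stashes the block total in
// temp[block_size] and zeroes the root, and the down-sweep converts the tree
// into an exclusive prefix sum. Starting last_sum at -1 shifts everything so
// the written output is the inclusive cumulative sum minus one (for example,
// a column [1, 0, 1, 1] yields [0, 0, 1, 2]), which serves as a 0-based
// running destination slot for the tokens selected by the mask.
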
// LAUNCH FUNCTIONS --------------------------------
template <typename T>
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
                         int *mask2, int *dest1, int *dest2, const int s,
                         const int h) {
  if (h < 256)
    moe_dpch_fwd_kernel<T, 32, 4>
        <<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
  else if (h < 512)
    moe_dpch_fwd_kernel<T, 32, 8>
        <<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
  else if (h < 1024)
    moe_dpch_fwd_kernel<T, 32, 16>
        <<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
  else if (h < 2048)
    moe_dpch_fwd_kernel<T, 64, 16>
        <<<s, 64>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
  else
    moe_dpch_fwd_kernel<T, 128, 16>
        <<<s, 128>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
}

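// The <block_size, pack_size> pairs appear tuned so that a row is covered in
// a couple of loop trips: e.g. for h < 256, 32 threads with a pack of 4 move
// 128 elements per trip. Capping pack_size at 16 bounds the per-thread
// register footprint; the same ladder is reused by the launchers below.
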
template <typename T>
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
                         int *dest1, int *dest2, const int s, const int h) {
  if (h < 256)
    moe_dpch_bwd_kernel<T, 32, 4>
        <<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
  else if (h < 512)
    moe_dpch_bwd_kernel<T, 32, 8>
        <<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
  else if (h < 1024)
    moe_dpch_bwd_kernel<T, 32, 16>
        <<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
  else if (h < 2048)
    moe_dpch_bwd_kernel<T, 64, 16>
        <<<s, 64>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
  else
    moe_dpch_bwd_kernel<T, 128, 16>
        <<<s, 128>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
}

template <typename T>
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
                       int *mask1, int *mask2, int *dest1, int *dest2,
                       const int s, const int e, const int c, const int h) {
  if (h < 256)
    moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
                                           logits, mask1, mask2, dest1, dest2,
                                           e, c, h);
  else if (h < 512)
    moe_cb_fwd_kernel<T, 32, 8><<<s, 32>>>(expert_tokens, combine_tokens,
                                           logits, mask1, mask2, dest1, dest2,
                                           e, c, h);
  else if (h < 1024)
    moe_cb_fwd_kernel<T, 32, 16><<<s, 32>>>(expert_tokens, combine_tokens,
                                            logits, mask1, mask2, dest1, dest2,
                                            e, c, h);
  else if (h < 2048)
    moe_cb_fwd_kernel<T, 64, 16><<<s, 64>>>(expert_tokens, combine_tokens,
                                            logits, mask1, mask2, dest1, dest2,
                                            e, c, h);
  else
    moe_cb_fwd_kernel<T, 128, 16><<<s, 128>>>(expert_tokens, combine_tokens,
                                              logits, mask1, mask2, dest1,
                                              dest2, e, c, h);
}

template <typename T>
void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
                       T *logits_grad, int *mask1, int *mask2, int *dest1,
                       int *dest2, const int s, const int e, const int c,
                       const int h) {
  if (h < 256)
    moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
                                           logits, logits_grad, mask1, mask2,
                                           dest1, dest2, e, c, h);
  else  // if (h < 512)
    moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
                                           logits, logits_grad, mask1, mask2,
                                           dest1, dest2, e, c, h);
  // else if (h < 1024)
  //   moe_cb_bwd_kernel<T, 128, 4><<<s, 128>>>
  //       (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
  //        dest1, dest2, e, c, h);
  // else
  //   moe_cb_bwd_kernel<T, 256, 4><<<s, 256>>>
  //       (tokens_grad, expert_grad, tks, logits, logits_grad, mask1, mask2,
  //        dest1, dest2, e, c, h);
}

void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
  if (s <= 256)
    cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
  else if (s <= 512)
    cumsum_kernel<512, 1><<<e, 512>>>(inputs, outputs, s, e);
  else if (s <= 1024)
    cumsum_kernel<1024, 1><<<e, 1024>>>(inputs, outputs, s, e);
  else if (s <= 2048)
    cumsum_kernel<1024, 2><<<e, 1024>>>(inputs, outputs, s, e);
  else
    cumsum_kernel<1024, 4><<<e, 1024>>>(inputs, outputs, s, e);
}

// API FUNCTIONS --------------------------------
#define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...)                        \
  switch (TYPE) {                                                       \
    case at::ScalarType::Float: {                                       \
      using scalar_t = float;                                           \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    case at::ScalarType::Half: {                                        \
      using scalar_t = at::Half;                                        \
      __VA_ARGS__;                                                      \
      break;                                                            \
    }                                                                   \
    default:                                                            \
      AT_ERROR(#NAME, " not implemented yet for specific data type.");  \
  }

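// Usage sketch (the helper below is hypothetical, not part of this file):
// the macro binds `scalar_t` to the concrete element type before running its
// body, in the spirit of PyTorch's AT_DISPATCH_* macros.
//
//   void print_elem_size(torch::Tensor t) {
//     DISPATCH_FLOAT_AND_HALF(t.scalar_type(), "print elem size",
//                             printf("%zu\n", sizeof(scalar_t)));
//   }
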
torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
                                        torch::Tensor batch_tokens,
                                        torch::Tensor mask,
                                        torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  auto res = torch::zeros(
      {ec, h},
      torch::dtype(batch_tokens.dtype()).device(batch_tokens.device()));
  auto k = mask.size(0);

  DISPATCH_FLOAT_AND_HALF(
      batch_tokens.scalar_type(), "moe dispatch forward",
      moe_dpch_fwd_launch<scalar_t>(
          batch_tokens.data<scalar_t>(), res.data<scalar_t>(),
          mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
          dest_idx[0].data<int>(),
          k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));

  return res;
}

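// mask and dest_idx arrive as stacked [k, s] int32 tensors, k == 1 for top-1
// routing and k == 2 for top-2. For k == 1, mask2 is passed as nullptr and
// dest_idx[0] is reused for dest2 so that the (unused) dest2[row] read in
// the kernel stays in bounds.
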
torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
                                         torch::Tensor expert_grad,
                                         torch::Tensor mask,
                                         torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  auto res = torch::zeros(
      {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
  auto k = mask.size(0);

  DISPATCH_FLOAT_AND_HALF(
      expert_grad.scalar_type(), "moe dispatch backward",
      moe_dpch_bwd_launch<scalar_t>(
          res.data<scalar_t>(), expert_grad.data<scalar_t>(),
          mask[0].data<int>(), k == 1 ? nullptr : mask[1].data<int>(),
          dest_idx[0].data<int>(),
          k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, h));

  return res;
}

torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
                                       torch::Tensor expert_tokens,
                                       torch::Tensor logits, torch::Tensor mask,
                                       torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  assert(expert_tokens.dtype() == logits.dtype());

  auto res = torch::zeros(
      {s, h},
      torch::dtype(expert_tokens.dtype()).device(expert_tokens.device()));
  auto k = mask.size(0);

  DISPATCH_FLOAT_AND_HALF(
      expert_tokens.scalar_type(), "moe combine forward",
      moe_cb_fwd_launch<scalar_t>(
          expert_tokens.data<scalar_t>(), res.data<scalar_t>(),
          logits.data<scalar_t>(), mask[0].data<int>(),
          k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
          k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
          h));

  return res;
}

std::vector<torch::Tensor>
moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
                          torch::Tensor expert_tokens, torch::Tensor logits,
                          torch::Tensor mask, torch::Tensor dest_idx) {
  assert(h % 16 == 0);
  assert(tokens_grad.dtype() == expert_tokens.dtype());
  assert(expert_tokens.dtype() == logits.dtype());

  auto egrad = torch::zeros(
           {e * c, h},
           torch::dtype(tokens_grad.dtype()).device(tokens_grad.device())),
       wgrad = torch::zeros(
           {s, e}, torch::dtype(logits.dtype()).device(logits.device()));
  auto k = mask.size(0);

  DISPATCH_FLOAT_AND_HALF(
      tokens_grad.scalar_type(), "moe combine backward",
      moe_cb_bwd_launch<scalar_t>(
          tokens_grad.data<scalar_t>(), egrad.data<scalar_t>(),
          expert_tokens.data<scalar_t>(), logits.data<scalar_t>(),
          wgrad.data<scalar_t>(), mask[0].data<int>(),
          k == 1 ? nullptr : mask[1].data<int>(), dest_idx[0].data<int>(),
          k == 1 ? dest_idx[0].data<int>() : dest_idx[1].data<int>(), s, e, c,
          h));

  return {egrad, wgrad};
}

torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
  assert(mask.dim() == 2);
  assert(mask.dtype() == torch::kInt32);

  const int s = mask.size(0), e = mask.size(1);
  auto res =
      torch::empty({s, e}, torch::dtype(torch::kInt32).device(mask.device()));
  cumsum_launch(mask.data<int>(), res.data<int>(), s, e);

  return res;
}
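
// These API functions are presumably exported to Python through a separate
// pybind11 binding file; a minimal sketch (exported names assumed, not part
// of this commit) could look like:
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("moe_dispatch_forward", &moe_dispatch_cuda_forward);
//     m.def("moe_dispatch_backward", &moe_dispatch_cuda_backward);
//     m.def("moe_combine_forward", &moe_combine_cuda_forward);
//     m.def("moe_combine_backward", &moe_combine_cuda_backward);
//     m.def("moe_cumsum", &cumsum_sub_one_in_dim0);
//   }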