Unverified Commit 58580b50 authored by ver217, committed by GitHub

Revert "[NFC] Hotfix/format (#984)" (#986)

This reverts commit 0772828f.
parent 0772828f
......@@ -2,4 +2,3 @@ from .initialize import (initialize, launch, launch_from_openmpi,
launch_from_slurm, launch_from_torch, get_default_parser)
__version__ = '0.0.1'
......@@ -251,8 +251,8 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
module_list = []
for start, end in partitions[pipeline_rank]:
module_list.append(
nn.Sequential(*[nn.Identity() for _ in range(start)], *layers[start:end],
module_list.append(nn.Sequential(*[nn.Identity() for _ in range(start)],
*layers[start:end],
*[nn.Identity() for _ in range(len(layers) - end)]))
if verbose:
logger = get_dist_logger()
......@@ -264,3 +264,4 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
logger.info(log_str, ranks=[0])
return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
\ No newline at end of file
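For context, build_pipeline_model slices an nn.Sequential uniformly across pipeline ranks and pads each chunk with nn.Identity layers so global layer indices stay aligned. Below is a minimal C++ sketch of the uniform split; it is an assumed model of partition_uniform for num_chunks == 1, and which ranks absorb the remainder is a guess, not taken from the actual source.

```cpp
#include <cstdio>
#include <utility>
#include <vector>

// Assumed model of partition_uniform for num_chunks == 1: split num_items
// into num_parts contiguous [start, end) ranges of near-equal size. The
// tie-breaking (earlier parts get the remainder) is illustrative only.
std::vector<std::pair<int, int>> partition_uniform(int num_items, int num_parts) {
  std::vector<std::pair<int, int>> parts;
  int base = num_items / num_parts;
  int rem = num_items % num_parts;
  int start = 0;
  for (int p = 0; p < num_parts; ++p) {
    int len = base + (p < rem ? 1 : 0);
    parts.emplace_back(start, start + len);  // [start, end)
    start += len;
  }
  return parts;
}

int main() {
  // 10 layers over 4 pipeline ranks -> [0,3) [3,6) [6,8) [8,10)
  for (auto [s, e] : partition_uniform(10, 4)) printf("[%d, %d)\n", s, e);
  return 0;
}
```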
......@@ -20,14 +20,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE
*/
#include "cpu_adam.h"
#include <iostream>
#include <math.h>
#include <memory>
#include <omp.h>
#include <string.h>
#include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits>
#include <unordered_map>
......@@ -84,7 +82,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
if ((t + TILE) > rounded_size)
copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
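The hunk above only toggles brace style, but the underlying pattern is worth a note: Step_1 walks the parameter array in TILE-sized chunks and clamps the final partial chunk. A host-side sketch follows; the TILE value here is illustrative, not the constant defined in cpu_adam.h.

```cpp
#include <cstddef>
#include <cstdio>

// Walk the array in TILE-sized chunks, clamp the last partial chunk, and
// derive the exclusive end offset exactly as the loop above does.
int main() {
  const size_t TILE = 4096;          // illustrative tile width
  const size_t rounded_size = 10000; // example extent
  for (size_t t = 0; t < rounded_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
    size_t offset = copy_size + t;  // exclusive end of this tile
    printf("tile [%zu, %zu) of width %zu\n", t, offset, copy_size);
  }
  return 0;
}
```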
......@@ -146,7 +145,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t;
if ((t + TILE) > _param_size)
copy_size = _param_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
......@@ -235,7 +235,8 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
if ((t + TILE) > rounded_size)
copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
......@@ -320,6 +321,7 @@ int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
s_optimizers[optimizer_id] = opt;
if (should_log) {
std::string avx_type = "";
#if defined(__AVX512__)
avx_type = "AVX512";
......@@ -384,7 +386,8 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
if ((t + TILE) > rounded_size)
copy_size = rounded_size - t;
size_t offset = copy_size + t;
#pragma omp parallel for
......@@ -460,27 +463,41 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
grad_half_precision, loss_scale);
}
int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
float epsilon, float weight_decay, bool bias_correction,
torch::Tensor &params, torch::Tensor &grads,
torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq,
float loss_scale) {
int adam_step(int optimizer_id,
size_t step,
float lr,
float beta1,
float beta2,
float epsilon,
float weight_decay,
bool bias_correction,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
float loss_scale)
{
auto params_c = params.contiguous();
auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous();
float *params_ptr = (float *)params_c.data_ptr();
float *grads_ptr = (float *)grads_c.data_ptr();
float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
float* params_ptr = (float*)params_c.data_ptr();
float* grads_ptr = (float*)grads_c.data_ptr();
float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
params_c.numel(), (params.options().dtype() == at::kHalf),
(grads.options().dtype() == at::kHalf), loss_scale);
opt->Step_8(params_ptr,
grads_ptr,
exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
(params.options().dtype() == at::kHalf),
(grads.options().dtype() == at::kHalf),
loss_scale);
return 0;
}
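adam_step pulls raw float pointers out of contiguous tensors and hands them to Step_8. The arithmetic itself is standard Adam/AdamW; a scalar sketch follows, with bias correction and the adamw_mode branch assumed from the textbook algorithm, since the vectorized kernel bodies are truncated in this diff.

```cpp
#include <cmath>
#include <cstdio>

// Scalar sketch of one Adam/AdamW update, matching the parameters that
// adam_step forwards to Step_8. Assumptions are marked in comments.
int main() {
  float p = 1.0f, g = 0.1f, m = 0.0f, v = 0.0f;
  const float lr = 1e-3f, beta1 = 0.9f, beta2 = 0.999f;
  const float eps = 1e-8f, weight_decay = 1e-2f;
  const bool adamw_mode = true;
  float beta1_t = 1.0f, beta2_t = 1.0f;  // running powers, cf. _betta1_t

  for (int step = 1; step <= 3; ++step) {
    beta1_t *= beta1;                        // what IncrementStep presumably maintains
    beta2_t *= beta2;
    m = beta1 * m + (1.0f - beta1) * g;      // first moment
    v = beta2 * v + (1.0f - beta2) * g * g;  // second moment
    float m_hat = m / (1.0f - beta1_t);      // bias-corrected moments
    float v_hat = v / (1.0f - beta2_t);
    if (adamw_mode) p -= lr * weight_decay * p;  // decoupled weight decay
    p -= lr * m_hat / (std::sqrt(v_hat) + eps);
    printf("step %d: p = %.6f\n", step, p);
  }
  return 0;
}
```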
......
......@@ -90,18 +90,12 @@ union AVX_Data {
bool grad_half_precision = false, float loss_scale = -1);
class Adam_Optimizer {
public:
public:
Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
float eps = 1e-8, float weight_decay = 0,
bool adamw_mode = true)
: _alpha(alpha),
_betta1(betta1),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_step(0),
: _alpha(alpha), _betta1(betta1), _betta2(betta2), _eps(eps),
_weight_decay(weight_decay), _betta1_t(1.0), _betta2_t(1.0), _step(0),
_adamw_mode(adamw_mode) {}
~Adam_Optimizer() {}
......@@ -141,7 +135,7 @@ class Adam_Optimizer {
}
}
private:
private:
float _alpha;
float _betta1;
float _betta2;
......
#include <cooperative_groups.h>
#include <chrono>
#include <ctime>
#include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
curandStatePhilox4_32_10_t *curandstate;
......@@ -165,7 +165,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) return;
if (i * 4 >= total_count)
return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
......@@ -201,7 +202,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) return;
if (i * 8 >= total_count)
return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
......@@ -259,7 +261,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) return;
if (i * 4 >= total_count)
return;
uint8_t m[4];
......@@ -286,7 +289,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) return;
if (i * 8 >= total_count)
return;
float4 *out4 = reinterpret_cast<float4 *>(out);
const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
......@@ -376,7 +380,8 @@ __global__ void ls_dropout_res_bias_kernel(
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) return;
if (i * 4 >= total_count)
return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
......@@ -419,7 +424,8 @@ __global__ void ls_dropout_res_bias_kernel(
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) return;
if (i * 8 >= total_count)
return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
......@@ -559,9 +565,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
}
__syncthreads();
for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
for (int i = 1; i < 32; i <<= 1)
sum += g.shfl_down(sum, i);
if (y == 0) tile[0][x] = sum;
if (y == 0)
tile[0][x] = sum;
__syncthreads();
if (threadIdx.x < 8) {
......@@ -613,9 +621,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
}
__syncthreads();
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
for (int i = 1; i < WARP_SIZE; i <<= 1)
sum += g.shfl_down(sum, i);
if (y == 0) tile[0][x] = sum;
if (y == 0)
tile[0][x] = sum;
__syncthreads();
if (threadIdx.x < 8) {
......@@ -679,7 +689,8 @@ __global__ void ls_dropout_act_bias_kernel(
const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) return;
if (i * 4 >= total_count)
return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
......@@ -724,7 +735,8 @@ __global__ void ls_dropout_act_bias_kernel(
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) return;
if (i * 8 >= total_count)
return;
curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state);
......@@ -885,9 +897,11 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
float sum = tile[threadIdx.y][threadIdx.x];
__syncthreads();
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
for (int i = 1; i < WARP_SIZE; i <<= 1)
sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
if (threadIdx.x == 0)
tile[0][threadIdx.y] = sum;
__syncthreads();
if (threadIdx.y == 0) {
......
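Across these dropout kernels the work assignment is vectorized: with float4 loads each thread owns four consecutive floats (eight halves in the half-precision paths), threads wholly past the end return early, and kept values are rescaled by 1/(1-ratio) (inverted dropout). A host-side sketch of the index math, not the kernel itself:

```cpp
#include <cstdio>

// The loop index i stands in for blockIdx.x * blockDim.x + threadIdx.x.
int main() {
  const int total_count = 16;
  const float ratio = 0.1f;
  const float scale = 1.f / (1.f - ratio);  // inverted-dropout rescale
  for (int i = 0; i < 8; ++i) {
    if (i * 4 >= total_count) continue;  // the kernel's early return
    printf("thread %d -> elements [%d, %d), scale %.3f\n",
           i, i * 4, i * 4 + 4, scale);
  }
  return 0;
}
```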
#include <cooperative_groups.h>
#include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
/**
......
......@@ -13,23 +13,22 @@ const float REDUCE_FLOAT_INF_NEG = -100000000.f;
const float REDUCE_FLOAT_INF_POS = 100000000.f;
const unsigned int WARP_REDUCE_SIZE = 32;
template <typename T>
__forceinline__ __device__ T warpReduceSum(T val) {
template <typename T> __forceinline__ __device__ T warpReduceSum(T val) {
for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
return val;
}
/* Calculate the sum of all elements in a block */
template <typename T>
__forceinline__ __device__ T blockReduceSum(T val) {
template <typename T> __forceinline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val);
if (lane == 0) shared[wid] = val;
if (lane == 0)
shared[wid] = val;
__syncthreads();
val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;
......
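warpReduceSum above is an XOR butterfly over __shfl_xor_sync, and blockReduceSum then combines per-warp results through shared memory. A host-side model of the butterfly (plain C++ simulating the 32 lanes; this illustrates the access pattern only and is not the CUDA intrinsic):

```cpp
#include <cstdio>

// partner[i] = lane[i ^ mask] mimics __shfl_xor_sync. After log2(32)
// rounds every lane holds the full sum, which is why the block reduce
// can read warp results from lane 0 of each warp.
int main() {
  const int WARP = 32;
  float lane[WARP];
  for (int i = 0; i < WARP; ++i) lane[i] = static_cast<float>(i);  // sum = 496

  for (int mask = WARP >> 1; mask > 0; mask >>= 1) {
    float partner[WARP];
    for (int i = 0; i < WARP; ++i) partner[i] = lane[i ^ mask];
    for (int i = 0; i < WARP; ++i) lane[i] += partner[i];
  }
  printf("lane 0 = %.0f, lane 17 = %.0f\n", lane[0], lane[17]);  // both 496
  return 0;
}
```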
......@@ -9,7 +9,7 @@
#include "cuda_util.h"
class Context {
public:
public:
Context() : _stream(nullptr) {
CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
}
......@@ -30,7 +30,7 @@ class Context {
cublasHandle_t get_cublashandle() { return _cublasHandle; }
private:
private:
cudaStream_t _stream;
cublasHandle_t _cublasHandle;
};
......@@ -8,9 +8,8 @@
#include "cuda_util.h"
template <typename T>
class CrossEntropyLayer {
public:
template <typename T> class CrossEntropyLayer {
public:
CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens);
virtual ~CrossEntropyLayer();
......@@ -23,7 +22,7 @@ class CrossEntropyLayer {
void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size);
private:
private:
void allocate_mem_buffer() {
// allocate local gpu memory
_loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2);
......
......@@ -20,8 +20,7 @@ void check_gpu_error(T result, char const *const func, const char *const file,
template <typename T>
void print_vec(const T *outv, std::string outn, int num_output_ele);
template <typename T>
T *cuda_malloc(size_t ele_num);
template <typename T> T *cuda_malloc(size_t ele_num);
void cuda_free(void *pdata);
......
......@@ -3,14 +3,12 @@
#include <cuda.h>
#include <cuda_fp16.h>
#include <stdio.h>
#include <string>
#include "kernels.h"
template <typename T>
class Dropout {
public:
template <typename T> class Dropout {
public:
struct Config {
float ratio;
bool training;
......@@ -90,7 +88,7 @@ class Dropout {
void SetTrainingMode(bool training) { _config.training = training; }
private:
private:
uint8_t *_mask;
Config _config;
};
......@@ -13,16 +13,14 @@
#include "cublas_wrappers.h"
#include "kernels.h"
template <typename T>
class FeedForward {
public:
template <typename T> class FeedForward {
public:
struct Config {
int outputSize;
int inputSize;
std::array<int, 3> gemm_algos;
Config(int outputs, int inputs)
: outputSize(outputs),
inputSize(inputs),
: outputSize(outputs), inputSize(inputs),
gemm_algos(std::array<int, 3>({99, 99, 99})) {}
};
......@@ -63,6 +61,6 @@ class FeedForward {
config_.inputSize = inputSize;
}
private:
private:
Config config_;
};
......@@ -10,9 +10,8 @@
using namespace std;
template <typename T>
class Softmax {
public:
template <typename T> class Softmax {
public:
struct Config {
size_t nhead;
Config(size_t nhead) : nhead(nhead) {}
......@@ -37,6 +36,6 @@ class Softmax {
void reset_size(size_t nhead) { config_.nhead = nhead; }
private:
private:
Config config_;
};
#include "block_reduce.h"
#include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
......
#include <cooperative_groups.h>
#include <math.h>
#include <cub/block/block_load.cuh>
......@@ -7,6 +6,8 @@
#include "block_reduce.h"
#include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
const float EPSILON = 1e-8f;
......@@ -303,7 +304,8 @@ __global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
for (int i = 1; i < WARP_SIZE; i <<= 1)
sum += g.shfl_xor(sum, i);
#pragma unroll
for (int i = 0; i < ITERATIONS; ++i) {
......
......@@ -2,13 +2,11 @@
* https://github.com/NVIDIA/apex
* with minor changes. */
#include <torch/extension.h>
#include "compat.h"
#include <cassert>
#include <torch/extension.h>
#include <vector>
#include "compat.h"
namespace {
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
......@@ -85,6 +83,7 @@ std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta,
double epsilon) {
CHECK_INPUT(input);
CHECK_INPUT(gamma);
CHECK_INPUT(beta);
......@@ -110,10 +109,11 @@ void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
double epsilon, at::Tensor *grad_input,
at::Tensor *grad_gamma, at::Tensor *grad_beta);
std::vector<at::Tensor> layer_norm_gradient_affine(
at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input,
at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta,
double epsilon) {
std::vector<at::Tensor>
layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar,
at::Tensor input, at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta, double epsilon) {
CHECK_INPUT(dout);
CHECK_INPUT(mean);
CHECK_INPUT(invvar);
......
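The CHECK_INPUT / CHECK_CUDA guards used in these bindings come from compat.h, which this diff does not show. In PyTorch C++ extensions they are conventionally defined as below; treat this as an assumption about the repo, not its verbatim source.

```cpp
#include <torch/extension.h>

// Conventional PyTorch-extension input guards (assumed definition).
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)
```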
......@@ -15,10 +15,10 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx);
std::vector<torch::Tensor> moe_combine_cuda_backward(
int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx);
std::vector<torch::Tensor>
moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits,
torch::Tensor mask, torch::Tensor dest_idx);
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
......@@ -33,6 +33,7 @@ torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
torch::Tensor moe_dispatch_forward(int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask, torch::Tensor dest_idx) {
CHECK_INPUT(batch_tokens);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
......@@ -44,6 +45,7 @@ torch::Tensor moe_dispatch_backward(int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(expert_grad);
CHECK_CUDA(mask);
CHECK_CUDA(dest_idx);
......@@ -55,6 +57,7 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(expert_tokens);
CHECK_INPUT(logits);
CHECK_CUDA(mask);
......@@ -64,12 +67,11 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
dest_idx);
}
std::vector<torch::Tensor> moe_combine_backward(int s, int e, int c, int h,
torch::Tensor tokens_grad,
torch::Tensor expert_tokens,
torch::Tensor logits,
torch::Tensor mask,
torch::Tensor dest_idx) {
std::vector<torch::Tensor>
moe_combine_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits,
torch::Tensor mask, torch::Tensor dest_idx) {
CHECK_INPUT(tokens_grad);
CHECK_INPUT(logits);
CHECK_CUDA(mask);
......
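The CUDA file below implements the dispatch and combine these bindings call. Routing is top-2: each token row carries two indicator masks (mask1, mask2) and two destination rows (dest1, dest2), and the *_selector helpers branch on the indicators. A sketch of that branching; the single-mask and drop branches are assumptions, since the selector bodies are truncated in this diff.

```cpp
#include <cstdio>

// Assumed model of moe_dpch_fwd_selector's indicator branching.
void dispatch_row(int row, int ind1, int ind2, int dest1, int dest2) {
  if (ind1 != 0 && ind2 != 0)
    printf("token %d -> expert rows %d and %d\n", row, dest1, dest2);
  else if (ind1 != 0)
    printf("token %d -> expert row %d\n", row, dest1);
  else if (ind2 != 0)
    printf("token %d -> expert row %d\n", row, dest2);
  else
    printf("token %d dropped\n", row);
}

int main() {
  dispatch_row(0, 1, 1, 3, 7);   // routed to both experts
  dispatch_row(1, 1, 0, 2, -1);  // routed to one expert
  return 0;
}
```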
#include "block_reduce.h"
#include <cub/cub.cuh>
#include <cuda.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <cub/cub.cuh>
#include "block_reduce.h"
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
......@@ -29,6 +28,7 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
......@@ -51,6 +51,7 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
......@@ -74,6 +75,7 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
......@@ -103,6 +105,7 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
......@@ -131,6 +134,7 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
T *weight_grad, const T weight, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
......@@ -160,13 +164,15 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
blockReduce<ReduceType::kSum, 1>(&thread_sum);
if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
if (threadIdx.x == 0)
*weight_grad = static_cast<T>(thread_sum);
}
template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
const T weight1, const T weight2,
const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
......@@ -198,6 +204,7 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
T *tks_row1, T *tks_row2, T *weight_grad1,
T *weight_grad2, const T weight1,
const T weight2, const int cols) {
assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size;
......@@ -244,6 +251,7 @@ template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
const int cols, const int indicator1,
const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
cols);
......@@ -259,6 +267,7 @@ template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
const int cols, const int indicator1,
const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
cols);
......@@ -274,6 +283,7 @@ template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
int *mask1, int *mask2, int *dest1,
int *dest2, const int h) {
int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_fwd_selector<T, block_size, pack_size>(
......@@ -285,6 +295,7 @@ template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
int *mask2, int *dest1, int *dest2,
const int h) {
int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_bwd_selector<T, block_size, pack_size>(
......@@ -299,6 +310,7 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
const int cols, const T weight1,
const T weight2, const int indicator1,
const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
weight1, weight2, cols);
......@@ -316,6 +328,7 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
T *wt_grad1, T *wt_grad2, const T weight1,
const T weight2, const int indicator1,
const int indicator2) {
if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
tks_row1, tks_row2, wt_grad1,
......@@ -335,6 +348,7 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
T *logits, int *mask1, int *mask2, int *dest1,
int *dest2, const int e, const int c,
const int h) {
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e);
......@@ -349,6 +363,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
T *logits, T *logits_grad, int *mask1,
int *mask2, int *dest1, int *dest2,
const int e, const int c, const int h) {
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
......@@ -364,6 +379,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
template <int block_size, int pack_size>
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
const int e) {
assert(s % pack_size == 0);
constexpr int bpack_size = block_size * pack_size;
int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
......@@ -410,7 +426,8 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
}
__syncthreads();
if (tid == 0) temp[0] = temp[block_size];
if (tid == 0)
temp[0] = temp[block_size];
__syncthreads();
if (idx + tps < s) {
......@@ -436,6 +453,7 @@ template <typename T>
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
int *mask2, int *dest1, int *dest2, const int s,
const int h) {
if (h < 256)
moe_dpch_fwd_kernel<T, 32, 4>
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
......@@ -456,6 +474,7 @@ void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
template <typename T>
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
int *dest1, int *dest2, const int s, const int h) {
if (h < 256)
moe_dpch_bwd_kernel<T, 32, 4>
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
......@@ -477,6 +496,7 @@ template <typename T>
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
int *mask1, int *mask2, int *dest1, int *dest2,
const int s, const int e, const int c, const int h) {
if (h < 256)
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1, dest2,
......@@ -504,6 +524,7 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
T *logits_grad, int *mask1, int *mask2, int *dest1,
int *dest2, const int s, const int e, const int c,
const int h) {
if (h < 256)
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
logits, logits_grad, mask1, mask2,
......@@ -523,6 +544,7 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
}
void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
if (s <= 256)
cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
else if (s <= 512)
......@@ -557,6 +579,7 @@ torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
torch::Tensor batch_tokens,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
auto res = torch::zeros(
{ec, h},
......@@ -578,6 +601,7 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
torch::Tensor expert_grad,
torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
auto res = torch::zeros(
{s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
......@@ -598,6 +622,7 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
torch::Tensor expert_tokens,
torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) {
assert(h % 16 == 0);
assert(expert_tokens.dtype() == logits.dtype());
......@@ -618,10 +643,11 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
return res;
}
std::vector<torch::Tensor> moe_combine_cuda_backward(
int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) {
std::vector<torch::Tensor>
moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits,
torch::Tensor mask, torch::Tensor dest_idx) {
assert(h % 16 == 0);
assert(tokens_grad.dtype() == expert_tokens.dtype());
assert(expert_tokens.dtype() == logits.dtype());
......@@ -647,6 +673,7 @@ std::vector<torch::Tensor> moe_combine_cuda_backward(
}
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
assert(mask.dim() == 2);
assert(mask.dtype() == torch::kInt32);
......
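cumsum_kernel scans each expert column tile by tile, carrying each tile's total into the next (the temp[0] = temp[block_size] handoff above). The name cumsum_sub_one_in_dim0 suggests the result is an inclusive cumsum minus one, giving selected tokens 0-based slot indices; the sketch below works under that assumption, with the in-block parallel scan replaced by a serial loop.

```cpp
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> mask = {1, 0, 1, 1, 0, 1};  // one expert column
  const size_t tile = 4;                       // stands in for block_size * pack_size
  std::vector<int> out(mask.size());
  int carry = 0;
  for (size_t t = 0; t < mask.size(); t += tile) {
    int run = carry;
    for (size_t i = t; i < mask.size() && i < t + tile; ++i) {
      run += mask[i];
      out[i] = run - 1;  // 0-based slot index for selected tokens
    }
    carry = run;  // tile total handed to the next tile
  }
  for (int v : out) printf("%d ", v);  // 0 0 1 2 2 3 (mask==0 rows unused)
  printf("\n");
  return 0;
}
```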
......@@ -16,8 +16,7 @@
#define BLOCK_SIZE 512
#define ILP 4
template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}
......@@ -29,10 +28,9 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}
template <typename x_t>
struct L2NormFunctor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
template <typename x_t> struct L2NormFunctor {
__device__ __forceinline__ void
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
float *output, float *output_per_tensor, bool per_tensor,
int max_chunks_per_tensor) {
// I'd like this kernel to propagate infs/nans.
......@@ -50,8 +48,8 @@ struct L2NormFunctor {
__shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be
// sure...
float
vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
......@@ -86,7 +84,8 @@ struct L2NormFunctor {
}
float val = 0.f;
for (int i = 0; i < ILP; i++) val += vals[i];
for (int i = 0; i < ILP; i++)
val += vals[i];
float final = reduce_block_into_lanes(s_vals, val);
......@@ -105,10 +104,9 @@ struct L2NormFunctor {
// Probably better to template, but since we are not likely to support other
// norm
template <typename x_t>
struct MaxNormFunctor {
__device__ __forceinline__ void operator()(
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
template <typename x_t> struct MaxNormFunctor {
__device__ __forceinline__ void
operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
float *output, float *output_per_tensor, bool per_tensor,
int max_chunks_per_tensor) {
// I'd like this kernel to propagate infs/nans.
......@@ -126,8 +124,8 @@ struct MaxNormFunctor {
__shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be
// sure...
float
vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP];
for (int i = 0; i < ILP; i++) {
vals[i] = 0.f;
......@@ -162,7 +160,8 @@ struct MaxNormFunctor {
}
float val = 0.f;
for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));
for (int i = 0; i < ILP; i++)
val = fmaxf(fabsf(val), fabsf(vals[i]));
float final = reduce_block_into_lanes_max_op(s_vals, val);
......@@ -186,11 +185,13 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
if (blockIdx.x == 0) {
float val = 0;
if (threadIdx.x < 320) val = output[threadIdx.x];
if (threadIdx.x < 320)
val = output[threadIdx.x];
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) *ret = sqrt(final);
if (threadIdx.x == 0)
*ret = sqrt(final);
}
if (per_tensor) {
......@@ -203,7 +204,8 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);
if (threadIdx.x == 0)
ret_per_tensor[blockIdx.x] = sqrt(final);
}
}
......@@ -215,14 +217,17 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
if (blockIdx.x == 0) {
float val = 0;
if (threadIdx.x < 320) val = output[threadIdx.x];
if (threadIdx.x < 320)
val = output[threadIdx.x];
if (norm_type == 0) {
float final = reduce_block_into_lanes_max_op(vals, val);
if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final;
if (threadIdx.x == 0)
*ret = alpha * (*ret) + beta * final;
} else {
float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
if (threadIdx.x == 0)
*ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
}
}
......@@ -255,8 +260,8 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
}
}
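These cleanup kernels finish two different reductions: cleanup takes the square root of the summed squares (L2 norm), while cleanup_v2 with norm_type == 0 takes the running maximum of absolute values, matching the branch above. A host-side sketch of both, with the chunked, per-tensor bookkeeping omitted:

```cpp
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  std::vector<float> grad = {0.5f, -2.0f, 1.5f};
  float sumsq = 0.f, maxabs = 0.f;
  for (float x : grad) {
    sumsq += x * x;                          // L2 path: sum of squares
    maxabs = std::fmax(maxabs, std::fabs(x)); // max-norm path
  }
  printf("l2 = %.4f, max = %.4f\n", std::sqrt(sumsq), maxabs);
  return 0;
}
```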
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
int chunk_size, at::Tensor noop_flag,
std::tuple<at::Tensor, at::Tensor>
multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python) {
bool per_tensor =
......