Unverified Commit 58580b50 authored by ver217's avatar ver217 Committed by GitHub
Browse files

Revert "[NFC] Hotfix/format (#984)" (#986)

This reverts commit 0772828f.
parent 0772828f
...@@ -2,4 +2,3 @@ from .initialize import (initialize, launch, launch_from_openmpi, ...@@ -2,4 +2,3 @@ from .initialize import (initialize, launch, launch_from_openmpi,
launch_from_slurm, launch_from_torch, get_default_parser) launch_from_slurm, launch_from_torch, get_default_parser)
__version__ = '0.0.1' __version__ = '0.0.1'
...@@ -251,8 +251,8 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo ...@@ -251,8 +251,8 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks) partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
module_list = [] module_list = []
for start, end in partitions[pipeline_rank]: for start, end in partitions[pipeline_rank]:
module_list.append( module_list.append(nn.Sequential(*[nn.Identity() for _ in range(start)],
nn.Sequential(*[nn.Identity() for _ in range(start)], *layers[start:end], *layers[start:end],
*[nn.Identity() for _ in range(len(layers) - end)])) *[nn.Identity() for _ in range(len(layers) - end)]))
if verbose: if verbose:
logger = get_dist_logger() logger = get_dist_logger()
...@@ -264,3 +264,4 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo ...@@ -264,3 +264,4 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n' log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
logger.info(log_str, ranks=[0]) logger.info(log_str, ranks=[0])
return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0] return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
\ No newline at end of file
...@@ -20,14 +20,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE ...@@ -20,14 +20,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE SOFTWARE
*/ */
#include "cpu_adam.h" #include "cpu_adam.h"
#include <iostream>
#include <math.h> #include <math.h>
#include <memory>
#include <omp.h> #include <omp.h>
#include <string.h> #include <string.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <iostream>
#include <memory>
#include <type_traits> #include <type_traits>
#include <unordered_map> #include <unordered_map>
...@@ -84,7 +82,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, ...@@ -84,7 +82,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
for (size_t t = 0; t < rounded_size; t += TILE) { for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE; size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t; if ((t + TILE) > rounded_size)
copy_size = rounded_size - t;
size_t offset = copy_size + t; size_t offset = copy_size + t;
#pragma omp parallel for #pragma omp parallel for
...@@ -146,7 +145,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg, ...@@ -146,7 +145,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
if (_param_size > rounded_size) { if (_param_size > rounded_size) {
for (size_t t = rounded_size; t < _param_size; t += TILE) { for (size_t t = rounded_size; t < _param_size; t += TILE) {
size_t copy_size = TILE; size_t copy_size = TILE;
if ((t + TILE) > _param_size) copy_size = _param_size - t; if ((t + TILE) > _param_size)
copy_size = _param_size - t;
size_t offset = copy_size + t; size_t offset = copy_size + t;
#pragma omp parallel for #pragma omp parallel for
...@@ -235,7 +235,8 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg, ...@@ -235,7 +235,8 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
for (size_t t = 0; t < rounded_size; t += TILE) { for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE; size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t; if ((t + TILE) > rounded_size)
copy_size = rounded_size - t;
size_t offset = copy_size + t; size_t offset = copy_size + t;
#pragma omp parallel for #pragma omp parallel for
...@@ -320,6 +321,7 @@ int create_adam_optimizer(int optimizer_id, float alpha = 1e-3, ...@@ -320,6 +321,7 @@ int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
s_optimizers[optimizer_id] = opt; s_optimizers[optimizer_id] = opt;
if (should_log) { if (should_log) {
std::string avx_type = ""; std::string avx_type = "";
#if defined(__AVX512__) #if defined(__AVX512__)
avx_type = "AVX512"; avx_type = "AVX512";
...@@ -384,7 +386,8 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, ...@@ -384,7 +386,8 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
for (size_t t = 0; t < rounded_size; t += TILE) { for (size_t t = 0; t < rounded_size; t += TILE) {
size_t copy_size = TILE; size_t copy_size = TILE;
if ((t + TILE) > rounded_size) copy_size = rounded_size - t; if ((t + TILE) > rounded_size)
copy_size = rounded_size - t;
size_t offset = copy_size + t; size_t offset = copy_size + t;
#pragma omp parallel for #pragma omp parallel for
...@@ -460,27 +463,41 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg, ...@@ -460,27 +463,41 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
grad_half_precision, loss_scale); grad_half_precision, loss_scale);
} }
int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2, int adam_step(int optimizer_id,
float epsilon, float weight_decay, bool bias_correction, size_t step,
torch::Tensor &params, torch::Tensor &grads, float lr,
torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq, float beta1,
float loss_scale) { float beta2,
float epsilon,
float weight_decay,
bool bias_correction,
torch::Tensor& params,
torch::Tensor& grads,
torch::Tensor& exp_avg,
torch::Tensor& exp_avg_sq,
float loss_scale)
{
auto params_c = params.contiguous(); auto params_c = params.contiguous();
auto grads_c = grads.contiguous(); auto grads_c = grads.contiguous();
auto exp_avg_c = exp_avg.contiguous(); auto exp_avg_c = exp_avg.contiguous();
auto exp_avg_sq_c = exp_avg_sq.contiguous(); auto exp_avg_sq_c = exp_avg_sq.contiguous();
float *params_ptr = (float *)params_c.data_ptr(); float* params_ptr = (float*)params_c.data_ptr();
float *grads_ptr = (float *)grads_c.data_ptr(); float* grads_ptr = (float*)grads_c.data_ptr();
float *exp_avg_ptr = (float *)exp_avg_c.data_ptr(); float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr(); float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
std::shared_ptr<Adam_Optimizer> opt = std::shared_ptr<Adam_Optimizer> opt =
std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]); std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
opt->IncrementStep(step, beta1, beta2); opt->IncrementStep(step, beta1, beta2);
opt->update_state(lr, epsilon, weight_decay, bias_correction); opt->update_state(lr, epsilon, weight_decay, bias_correction);
opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr, opt->Step_8(params_ptr,
params_c.numel(), (params.options().dtype() == at::kHalf), grads_ptr,
(grads.options().dtype() == at::kHalf), loss_scale); exp_avg_ptr,
exp_avg_sq_ptr,
params_c.numel(),
(params.options().dtype() == at::kHalf),
(grads.options().dtype() == at::kHalf),
loss_scale);
return 0; return 0;
} }
......
...@@ -90,18 +90,12 @@ union AVX_Data { ...@@ -90,18 +90,12 @@ union AVX_Data {
bool grad_half_precision = false, float loss_scale = -1); bool grad_half_precision = false, float loss_scale = -1);
class Adam_Optimizer { class Adam_Optimizer {
public: public:
Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999, Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
float eps = 1e-8, float weight_decay = 0, float eps = 1e-8, float weight_decay = 0,
bool adamw_mode = true) bool adamw_mode = true)
: _alpha(alpha), : _alpha(alpha), _betta1(betta1), _betta2(betta2), _eps(eps),
_betta1(betta1), _weight_decay(weight_decay), _betta1_t(1.0), _betta2_t(1.0), _step(0),
_betta2(betta2),
_eps(eps),
_weight_decay(weight_decay),
_betta1_t(1.0),
_betta2_t(1.0),
_step(0),
_adamw_mode(adamw_mode) {} _adamw_mode(adamw_mode) {}
~Adam_Optimizer() {} ~Adam_Optimizer() {}
...@@ -141,7 +135,7 @@ class Adam_Optimizer { ...@@ -141,7 +135,7 @@ class Adam_Optimizer {
} }
} }
private: private:
float _alpha; float _alpha;
float _betta1; float _betta1;
float _betta2; float _betta2;
......
#include <cooperative_groups.h>
#include <chrono> #include <chrono>
#include <ctime> #include <ctime>
#include "kernels.h" #include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
curandStatePhilox4_32_10_t *curandstate; curandStatePhilox4_32_10_t *curandstate;
...@@ -165,7 +165,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio, ...@@ -165,7 +165,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
const float scale = 1.f / (1.f - ratio); const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) return; if (i * 4 >= total_count)
return;
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state); curand_init(seed, i, 0, &state);
...@@ -201,7 +202,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio, ...@@ -201,7 +202,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) return; if (i * 8 >= total_count)
return;
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state); curand_init(seed, i, 0, &state);
...@@ -259,7 +261,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, ...@@ -259,7 +261,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
const float scale = 1.f / (1.f - ratio); const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) return; if (i * 4 >= total_count)
return;
uint8_t m[4]; uint8_t m[4];
...@@ -286,7 +289,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio, ...@@ -286,7 +289,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) return; if (i * 8 >= total_count)
return;
float4 *out4 = reinterpret_cast<float4 *>(out); float4 *out4 = reinterpret_cast<float4 *>(out);
const float4 *vals_float4 = reinterpret_cast<const float4 *>(in); const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
...@@ -376,7 +380,8 @@ __global__ void ls_dropout_res_bias_kernel( ...@@ -376,7 +380,8 @@ __global__ void ls_dropout_res_bias_kernel(
const float scale = 1.f / (1.f - ratio); const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) return; if (i * 4 >= total_count)
return;
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state); curand_init(seed, i, 0, &state);
...@@ -419,7 +424,8 @@ __global__ void ls_dropout_res_bias_kernel( ...@@ -419,7 +424,8 @@ __global__ void ls_dropout_res_bias_kernel(
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) return; if (i * 8 >= total_count)
return;
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state); curand_init(seed, i, 0, &state);
...@@ -559,9 +565,11 @@ __global__ void ls_dropout_bias_bwd_kernel( ...@@ -559,9 +565,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
} }
__syncthreads(); __syncthreads();
for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i); for (int i = 1; i < 32; i <<= 1)
sum += g.shfl_down(sum, i);
if (y == 0) tile[0][x] = sum; if (y == 0)
tile[0][x] = sum;
__syncthreads(); __syncthreads();
if (threadIdx.x < 8) { if (threadIdx.x < 8) {
...@@ -613,9 +621,11 @@ __global__ void ls_dropout_bias_bwd_kernel( ...@@ -613,9 +621,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
} }
__syncthreads(); __syncthreads();
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); for (int i = 1; i < WARP_SIZE; i <<= 1)
sum += g.shfl_down(sum, i);
if (y == 0) tile[0][x] = sum; if (y == 0)
tile[0][x] = sum;
__syncthreads(); __syncthreads();
if (threadIdx.x < 8) { if (threadIdx.x < 8) {
...@@ -679,7 +689,8 @@ __global__ void ls_dropout_act_bias_kernel( ...@@ -679,7 +689,8 @@ __global__ void ls_dropout_act_bias_kernel(
const float scale = 1.f / (1.f - ratio); const float scale = 1.f / (1.f - ratio);
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 4 >= total_count) return; if (i * 4 >= total_count)
return;
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state); curand_init(seed, i, 0, &state);
...@@ -724,7 +735,8 @@ __global__ void ls_dropout_act_bias_kernel( ...@@ -724,7 +735,8 @@ __global__ void ls_dropout_act_bias_kernel(
int i = blockIdx.x * blockDim.x + threadIdx.x; int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i * 8 >= total_count) return; if (i * 8 >= total_count)
return;
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, i, 0, &state); curand_init(seed, i, 0, &state);
...@@ -885,9 +897,11 @@ __global__ void ls_dropout_act_bias_bwd_kernel( ...@@ -885,9 +897,11 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
float sum = tile[threadIdx.y][threadIdx.x]; float sum = tile[threadIdx.y][threadIdx.x];
__syncthreads(); __syncthreads();
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i); for (int i = 1; i < WARP_SIZE; i <<= 1)
sum += g.shfl_down(sum, i);
if (threadIdx.x == 0) tile[0][threadIdx.y] = sum; if (threadIdx.x == 0)
tile[0][threadIdx.y] = sum;
__syncthreads(); __syncthreads();
if (threadIdx.y == 0) { if (threadIdx.y == 0) {
......
#include <cooperative_groups.h>
#include "kernels.h" #include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
/** /**
......
...@@ -13,23 +13,22 @@ const float REDUCE_FLOAT_INF_NEG = -100000000.f; ...@@ -13,23 +13,22 @@ const float REDUCE_FLOAT_INF_NEG = -100000000.f;
const float REDUCE_FLOAT_INF_POS = 100000000.f; const float REDUCE_FLOAT_INF_POS = 100000000.f;
const unsigned int WARP_REDUCE_SIZE = 32; const unsigned int WARP_REDUCE_SIZE = 32;
template <typename T> template <typename T> __forceinline__ __device__ T warpReduceSum(T val) {
__forceinline__ __device__ T warpReduceSum(T val) {
for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1) for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE); val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
return val; return val;
} }
/* Calculate the sum of all elements in a block */ /* Calculate the sum of all elements in a block */
template <typename T> template <typename T> __forceinline__ __device__ T blockReduceSum(T val) {
__forceinline__ __device__ T blockReduceSum(T val) {
static __shared__ T shared[32]; static __shared__ T shared[32];
int lane = threadIdx.x & 0x1f; int lane = threadIdx.x & 0x1f;
int wid = threadIdx.x >> 5; int wid = threadIdx.x >> 5;
val = warpReduceSum<T>(val); val = warpReduceSum<T>(val);
if (lane == 0) shared[wid] = val; if (lane == 0)
shared[wid] = val;
__syncthreads(); __syncthreads();
val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f; val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;
......
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
#include "cuda_util.h" #include "cuda_util.h"
class Context { class Context {
public: public:
Context() : _stream(nullptr) { Context() : _stream(nullptr) {
CHECK_GPU_ERROR(cublasCreate(&_cublasHandle)); CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
} }
...@@ -30,7 +30,7 @@ class Context { ...@@ -30,7 +30,7 @@ class Context {
cublasHandle_t get_cublashandle() { return _cublasHandle; } cublasHandle_t get_cublashandle() { return _cublasHandle; }
private: private:
cudaStream_t _stream; cudaStream_t _stream;
cublasHandle_t _cublasHandle; cublasHandle_t _cublasHandle;
}; };
...@@ -8,9 +8,8 @@ ...@@ -8,9 +8,8 @@
#include "cuda_util.h" #include "cuda_util.h"
template <typename T> template <typename T> class CrossEntropyLayer {
class CrossEntropyLayer { public:
public:
CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens); CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens);
virtual ~CrossEntropyLayer(); virtual ~CrossEntropyLayer();
...@@ -23,7 +22,7 @@ class CrossEntropyLayer { ...@@ -23,7 +22,7 @@ class CrossEntropyLayer {
void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size); void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size);
private: private:
void allocate_mem_buffer() { void allocate_mem_buffer() {
// allocate local gpu memory // allocate local gpu memory
_loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2); _loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2);
......
...@@ -20,8 +20,7 @@ void check_gpu_error(T result, char const *const func, const char *const file, ...@@ -20,8 +20,7 @@ void check_gpu_error(T result, char const *const func, const char *const file,
template <typename T> template <typename T>
void print_vec(const T *outv, std::string outn, int num_output_ele); void print_vec(const T *outv, std::string outn, int num_output_ele);
template <typename T> template <typename T> T *cuda_malloc(size_t ele_num);
T *cuda_malloc(size_t ele_num);
void cuda_free(void *pdata); void cuda_free(void *pdata);
......
...@@ -3,14 +3,12 @@ ...@@ -3,14 +3,12 @@
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <stdio.h> #include <stdio.h>
#include <string> #include <string>
#include "kernels.h" #include "kernels.h"
template <typename T> template <typename T> class Dropout {
class Dropout { public:
public:
struct Config { struct Config {
float ratio; float ratio;
bool training; bool training;
...@@ -90,7 +88,7 @@ class Dropout { ...@@ -90,7 +88,7 @@ class Dropout {
void SetTrainingMode(bool training) { _config.training = training; } void SetTrainingMode(bool training) { _config.training = training; }
private: private:
uint8_t *_mask; uint8_t *_mask;
Config _config; Config _config;
}; };
...@@ -13,16 +13,14 @@ ...@@ -13,16 +13,14 @@
#include "cublas_wrappers.h" #include "cublas_wrappers.h"
#include "kernels.h" #include "kernels.h"
template <typename T> template <typename T> class FeedForward {
class FeedForward { public:
public:
struct Config { struct Config {
int outputSize; int outputSize;
int inputSize; int inputSize;
std::array<int, 3> gemm_algos; std::array<int, 3> gemm_algos;
Config(int outputs, int inputs) Config(int outputs, int inputs)
: outputSize(outputs), : outputSize(outputs), inputSize(inputs),
inputSize(inputs),
gemm_algos(std::array<int, 3>({99, 99, 99})) {} gemm_algos(std::array<int, 3>({99, 99, 99})) {}
}; };
...@@ -63,6 +61,6 @@ class FeedForward { ...@@ -63,6 +61,6 @@ class FeedForward {
config_.inputSize = inputSize; config_.inputSize = inputSize;
} }
private: private:
Config config_; Config config_;
}; };
...@@ -10,9 +10,8 @@ ...@@ -10,9 +10,8 @@
using namespace std; using namespace std;
template <typename T> template <typename T> class Softmax {
class Softmax { public:
public:
struct Config { struct Config {
size_t nhead; size_t nhead;
Config(size_t nhead) : nhead(nhead) {} Config(size_t nhead) : nhead(nhead) {}
...@@ -37,6 +36,6 @@ class Softmax { ...@@ -37,6 +36,6 @@ class Softmax {
void reset_size(size_t nhead) { config_.nhead = nhead; } void reset_size(size_t nhead) { config_.nhead = nhead; }
private: private:
Config config_; Config config_;
}; };
#include "block_reduce.h" #include "block_reduce.h"
#include "kernels.h" #include "kernels.h"
#include <cooperative_groups.h> #include <cooperative_groups.h>
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
......
#include <cooperative_groups.h>
#include <math.h> #include <math.h>
#include <cub/block/block_load.cuh> #include <cub/block/block_load.cuh>
...@@ -7,6 +6,8 @@ ...@@ -7,6 +6,8 @@
#include "block_reduce.h" #include "block_reduce.h"
#include "kernels.h" #include "kernels.h"
#include <cooperative_groups.h>
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
const float EPSILON = 1e-8f; const float EPSILON = 1e-8f;
...@@ -303,7 +304,8 @@ __global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) { ...@@ -303,7 +304,8 @@ __global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
cg::thread_block b = cg::this_thread_block(); cg::thread_block b = cg::this_thread_block();
cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b); cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i); for (int i = 1; i < WARP_SIZE; i <<= 1)
sum += g.shfl_xor(sum, i);
#pragma unroll #pragma unroll
for (int i = 0; i < ITERATIONS; ++i) { for (int i = 0; i < ITERATIONS; ++i) {
......
...@@ -2,13 +2,11 @@ ...@@ -2,13 +2,11 @@
* https://github.com/NVIDIA/apex * https://github.com/NVIDIA/apex
* with minor changes. */ * with minor changes. */
#include <torch/extension.h> #include "compat.h"
#include <cassert> #include <cassert>
#include <torch/extension.h>
#include <vector> #include <vector>
#include "compat.h"
namespace { namespace {
void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1, void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
...@@ -85,6 +83,7 @@ std::vector<at::Tensor> layer_norm_affine(at::Tensor input, ...@@ -85,6 +83,7 @@ std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
at::IntArrayRef normalized_shape, at::IntArrayRef normalized_shape,
at::Tensor gamma, at::Tensor beta, at::Tensor gamma, at::Tensor beta,
double epsilon) { double epsilon) {
CHECK_INPUT(input); CHECK_INPUT(input);
CHECK_INPUT(gamma); CHECK_INPUT(gamma);
CHECK_INPUT(beta); CHECK_INPUT(beta);
...@@ -110,10 +109,11 @@ void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean, ...@@ -110,10 +109,11 @@ void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
double epsilon, at::Tensor *grad_input, double epsilon, at::Tensor *grad_input,
at::Tensor *grad_gamma, at::Tensor *grad_beta); at::Tensor *grad_gamma, at::Tensor *grad_beta);
std::vector<at::Tensor> layer_norm_gradient_affine( std::vector<at::Tensor>
at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input, layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar,
at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta, at::Tensor input, at::IntArrayRef normalized_shape,
double epsilon) { at::Tensor gamma, at::Tensor beta, double epsilon) {
CHECK_INPUT(dout); CHECK_INPUT(dout);
CHECK_INPUT(mean); CHECK_INPUT(mean);
CHECK_INPUT(invvar); CHECK_INPUT(invvar);
......
...@@ -15,10 +15,10 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, ...@@ -15,10 +15,10 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
torch::Tensor logits, torch::Tensor mask, torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx); torch::Tensor dest_idx);
std::vector<torch::Tensor> moe_combine_cuda_backward( std::vector<torch::Tensor>
int s, int e, int c, int h, torch::Tensor tokens_grad, moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, torch::Tensor expert_tokens, torch::Tensor logits,
torch::Tensor dest_idx); torch::Tensor mask, torch::Tensor dest_idx);
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
...@@ -33,6 +33,7 @@ torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask); ...@@ -33,6 +33,7 @@ torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
torch::Tensor moe_dispatch_forward(int s, int ec, int h, torch::Tensor moe_dispatch_forward(int s, int ec, int h,
torch::Tensor batch_tokens, torch::Tensor batch_tokens,
torch::Tensor mask, torch::Tensor dest_idx) { torch::Tensor mask, torch::Tensor dest_idx) {
CHECK_INPUT(batch_tokens); CHECK_INPUT(batch_tokens);
CHECK_CUDA(mask); CHECK_CUDA(mask);
CHECK_CUDA(dest_idx); CHECK_CUDA(dest_idx);
...@@ -44,6 +45,7 @@ torch::Tensor moe_dispatch_backward(int s, int ec, int h, ...@@ -44,6 +45,7 @@ torch::Tensor moe_dispatch_backward(int s, int ec, int h,
torch::Tensor expert_grad, torch::Tensor expert_grad,
torch::Tensor mask, torch::Tensor mask,
torch::Tensor dest_idx) { torch::Tensor dest_idx) {
CHECK_INPUT(expert_grad); CHECK_INPUT(expert_grad);
CHECK_CUDA(mask); CHECK_CUDA(mask);
CHECK_CUDA(dest_idx); CHECK_CUDA(dest_idx);
...@@ -55,6 +57,7 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h, ...@@ -55,6 +57,7 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
torch::Tensor expert_tokens, torch::Tensor expert_tokens,
torch::Tensor logits, torch::Tensor mask, torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) { torch::Tensor dest_idx) {
CHECK_INPUT(expert_tokens); CHECK_INPUT(expert_tokens);
CHECK_INPUT(logits); CHECK_INPUT(logits);
CHECK_CUDA(mask); CHECK_CUDA(mask);
...@@ -64,12 +67,11 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h, ...@@ -64,12 +67,11 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
dest_idx); dest_idx);
} }
std::vector<torch::Tensor> moe_combine_backward(int s, int e, int c, int h, std::vector<torch::Tensor>
torch::Tensor tokens_grad, moe_combine_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor expert_tokens, torch::Tensor logits,
torch::Tensor logits, torch::Tensor mask, torch::Tensor dest_idx) {
torch::Tensor mask,
torch::Tensor dest_idx) {
CHECK_INPUT(tokens_grad); CHECK_INPUT(tokens_grad);
CHECK_INPUT(logits); CHECK_INPUT(logits);
CHECK_CUDA(mask); CHECK_CUDA(mask);
......
#include "block_reduce.h"
#include <cub/cub.cuh>
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <cub/cub.cuh>
#include "block_reduce.h"
template <typename T, int block_size, int pack_size> template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0); assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size; const int bpack_size = block_size * pack_size;
...@@ -29,6 +28,7 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { ...@@ -29,6 +28,7 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
template <typename T, int block_size, int pack_size> template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0); assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size; const int bpack_size = block_size * pack_size;
...@@ -51,6 +51,7 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) { ...@@ -51,6 +51,7 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
template <typename T, int block_size, int pack_size> template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
const int cols) { const int cols) {
assert(cols % pack_size == 0); assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size; const int bpack_size = block_size * pack_size;
...@@ -74,6 +75,7 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2, ...@@ -74,6 +75,7 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
template <typename T, int block_size, int pack_size> template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
const int cols) { const int cols) {
assert(cols % pack_size == 0); assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size; const int bpack_size = block_size * pack_size;
...@@ -103,6 +105,7 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2, ...@@ -103,6 +105,7 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
template <typename T, int block_size, int pack_size> template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
const int cols) { const int cols) {
assert(cols % pack_size == 0); assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size; const int bpack_size = block_size * pack_size;
...@@ -131,6 +134,7 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight, ...@@ -131,6 +134,7 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
template <typename T, int block_size, int pack_size> template <typename T, int block_size, int pack_size>
__device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
T *weight_grad, const T weight, const int cols) { T *weight_grad, const T weight, const int cols) {
assert(cols % pack_size == 0); assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size; const int bpack_size = block_size * pack_size;
...@@ -160,13 +164,15 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row, ...@@ -160,13 +164,15 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
blockReduce<ReduceType::kSum, 1>(&thread_sum); blockReduce<ReduceType::kSum, 1>(&thread_sum);
if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum); if (threadIdx.x == 0)
*weight_grad = static_cast<T>(thread_sum);
} }
template <typename T, int block_size, int pack_size> template <typename T, int block_size, int pack_size>
__device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row, __device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
const T weight1, const T weight2, const T weight1, const T weight2,
const int cols) { const int cols) {
assert(cols % pack_size == 0); assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size; const int bpack_size = block_size * pack_size;
...@@ -198,6 +204,7 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row, ...@@ -198,6 +204,7 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
T *tks_row1, T *tks_row2, T *weight_grad1, T *tks_row1, T *tks_row2, T *weight_grad1,
T *weight_grad2, const T weight1, T *weight_grad2, const T weight1,
const T weight2, const int cols) { const T weight2, const int cols) {
assert(cols % pack_size == 0); assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size; const int bpack_size = block_size * pack_size;
...@@ -244,6 +251,7 @@ template <typename T, int block_size, int pack_size> ...@@ -244,6 +251,7 @@ template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2, __device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
const int cols, const int indicator1, const int cols, const int indicator1,
const int indicator2) { const int indicator2) {
if (indicator1 != 0 && indicator2 != 0) if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2, moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
cols); cols);
...@@ -259,6 +267,7 @@ template <typename T, int block_size, int pack_size> ...@@ -259,6 +267,7 @@ template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2, __device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
const int cols, const int indicator1, const int cols, const int indicator1,
const int indicator2) { const int indicator2) {
if (indicator1 != 0 && indicator2 != 0) if (indicator1 != 0 && indicator2 != 0)
moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2, moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
cols); cols);
...@@ -274,6 +283,7 @@ template <typename T, int block_size, int pack_size> ...@@ -274,6 +283,7 @@ template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input, __global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
int *mask1, int *mask2, int *dest1, int *mask1, int *mask2, int *dest1,
int *dest2, const int h) { int *dest2, const int h) {
int row = blockIdx.x; int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row]; int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_fwd_selector<T, block_size, pack_size>( moe_dpch_fwd_selector<T, block_size, pack_size>(
...@@ -285,6 +295,7 @@ template <typename T, int block_size, int pack_size> ...@@ -285,6 +295,7 @@ template <typename T, int block_size, int pack_size>
__global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1, __global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
int *mask2, int *dest1, int *dest2, int *mask2, int *dest1, int *dest2,
const int h) { const int h) {
int row = blockIdx.x; int row = blockIdx.x;
int indicator2 = mask2 == nullptr ? 0 : mask2[row]; int indicator2 = mask2 == nullptr ? 0 : mask2[row];
moe_dpch_bwd_selector<T, block_size, pack_size>( moe_dpch_bwd_selector<T, block_size, pack_size>(
...@@ -299,6 +310,7 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row, ...@@ -299,6 +310,7 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
const int cols, const T weight1, const int cols, const T weight1,
const T weight2, const int indicator1, const T weight2, const int indicator1,
const int indicator2) { const int indicator2) {
if (indicator1 != 0 && indicator2 != 0) if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row, moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
weight1, weight2, cols); weight1, weight2, cols);
...@@ -316,6 +328,7 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row, ...@@ -316,6 +328,7 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
T *wt_grad1, T *wt_grad2, const T weight1, T *wt_grad1, T *wt_grad2, const T weight1,
const T weight2, const int indicator1, const T weight2, const int indicator1,
const int indicator2) { const int indicator2) {
if (indicator1 != 0 && indicator2 != 0) if (indicator1 != 0 && indicator2 != 0)
moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row, moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
tks_row1, tks_row2, wt_grad1, tks_row1, tks_row2, wt_grad1,
...@@ -335,6 +348,7 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens, ...@@ -335,6 +348,7 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
T *logits, int *mask1, int *mask2, int *dest1, T *logits, int *mask1, int *mask2, int *dest1,
int *dest2, const int e, const int c, int *dest2, const int e, const int c,
const int h) { const int h) {
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row]; int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e); T *row_log = logits + (row * e);
...@@ -349,6 +363,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, ...@@ -349,6 +363,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
T *logits, T *logits_grad, int *mask1, T *logits, T *logits_grad, int *mask1,
int *mask2, int *dest1, int *dest2, int *mask2, int *dest1, int *dest2,
const int e, const int c, const int h) { const int e, const int c, const int h) {
int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c; int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
int indicator2 = mask2 == nullptr ? 0 : mask2[row]; int indicator2 = mask2 == nullptr ? 0 : mask2[row];
T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e); T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
...@@ -364,6 +379,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks, ...@@ -364,6 +379,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
template <int block_size, int pack_size> template <int block_size, int pack_size>
__global__ void cumsum_kernel(int *inputs, int *outputs, const int s, __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
const int e) { const int e) {
assert(s % pack_size == 0); assert(s % pack_size == 0);
constexpr int bpack_size = block_size * pack_size; constexpr int bpack_size = block_size * pack_size;
int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1; int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
...@@ -410,7 +426,8 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s, ...@@ -410,7 +426,8 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
} }
__syncthreads(); __syncthreads();
if (tid == 0) temp[0] = temp[block_size]; if (tid == 0)
temp[0] = temp[block_size];
__syncthreads(); __syncthreads();
if (idx + tps < s) { if (idx + tps < s) {
...@@ -436,6 +453,7 @@ template <typename T> ...@@ -436,6 +453,7 @@ template <typename T>
void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
int *mask2, int *dest1, int *dest2, const int s, int *mask2, int *dest1, int *dest2, const int s,
const int h) { const int h) {
if (h < 256) if (h < 256)
moe_dpch_fwd_kernel<T, 32, 4> moe_dpch_fwd_kernel<T, 32, 4>
<<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h); <<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
...@@ -456,6 +474,7 @@ void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1, ...@@ -456,6 +474,7 @@ void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
template <typename T> template <typename T>
void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2, void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
int *dest1, int *dest2, const int s, const int h) { int *dest1, int *dest2, const int s, const int h) {
if (h < 256) if (h < 256)
moe_dpch_bwd_kernel<T, 32, 4> moe_dpch_bwd_kernel<T, 32, 4>
<<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h); <<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
...@@ -477,6 +496,7 @@ template <typename T> ...@@ -477,6 +496,7 @@ template <typename T>
void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits, void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
int *mask1, int *mask2, int *dest1, int *dest2, int *mask1, int *mask2, int *dest1, int *dest2,
const int s, const int e, const int c, const int h) { const int s, const int e, const int c, const int h) {
if (h < 256) if (h < 256)
moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens, moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
logits, mask1, mask2, dest1, dest2, logits, mask1, mask2, dest1, dest2,
...@@ -504,6 +524,7 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, ...@@ -504,6 +524,7 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
T *logits_grad, int *mask1, int *mask2, int *dest1, T *logits_grad, int *mask1, int *mask2, int *dest1,
int *dest2, const int s, const int e, const int c, int *dest2, const int s, const int e, const int c,
const int h) { const int h) {
if (h < 256) if (h < 256)
moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks, moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
logits, logits_grad, mask1, mask2, logits, logits_grad, mask1, mask2,
...@@ -523,6 +544,7 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits, ...@@ -523,6 +544,7 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
} }
void cumsum_launch(int *inputs, int *outputs, const int s, const int e) { void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
if (s <= 256) if (s <= 256)
cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e); cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
else if (s <= 512) else if (s <= 512)
...@@ -557,6 +579,7 @@ torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h, ...@@ -557,6 +579,7 @@ torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
torch::Tensor batch_tokens, torch::Tensor batch_tokens,
torch::Tensor mask, torch::Tensor mask,
torch::Tensor dest_idx) { torch::Tensor dest_idx) {
assert(h % 16 == 0); assert(h % 16 == 0);
auto res = torch::zeros( auto res = torch::zeros(
{ec, h}, {ec, h},
...@@ -578,6 +601,7 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h, ...@@ -578,6 +601,7 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
torch::Tensor expert_grad, torch::Tensor expert_grad,
torch::Tensor mask, torch::Tensor mask,
torch::Tensor dest_idx) { torch::Tensor dest_idx) {
assert(h % 16 == 0); assert(h % 16 == 0);
auto res = torch::zeros( auto res = torch::zeros(
{s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device())); {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
...@@ -598,6 +622,7 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, ...@@ -598,6 +622,7 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
torch::Tensor expert_tokens, torch::Tensor expert_tokens,
torch::Tensor logits, torch::Tensor mask, torch::Tensor logits, torch::Tensor mask,
torch::Tensor dest_idx) { torch::Tensor dest_idx) {
assert(h % 16 == 0); assert(h % 16 == 0);
assert(expert_tokens.dtype() == logits.dtype()); assert(expert_tokens.dtype() == logits.dtype());
...@@ -618,10 +643,11 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h, ...@@ -618,10 +643,11 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
return res; return res;
} }
std::vector<torch::Tensor> moe_combine_cuda_backward( std::vector<torch::Tensor>
int s, int e, int c, int h, torch::Tensor tokens_grad, moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask, torch::Tensor expert_tokens, torch::Tensor logits,
torch::Tensor dest_idx) { torch::Tensor mask, torch::Tensor dest_idx) {
assert(h % 16 == 0); assert(h % 16 == 0);
assert(tokens_grad.dtype() == expert_tokens.dtype()); assert(tokens_grad.dtype() == expert_tokens.dtype());
assert(expert_tokens.dtype() == logits.dtype()); assert(expert_tokens.dtype() == logits.dtype());
...@@ -647,6 +673,7 @@ std::vector<torch::Tensor> moe_combine_cuda_backward( ...@@ -647,6 +673,7 @@ std::vector<torch::Tensor> moe_combine_cuda_backward(
} }
torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) { torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
assert(mask.dim() == 2); assert(mask.dim() == 2);
assert(mask.dtype() == torch::kInt32); assert(mask.dtype() == torch::kInt32);
......
...@@ -16,8 +16,7 @@ ...@@ -16,8 +16,7 @@
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
#define ILP 4 #define ILP 4
template <typename T> template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
__device__ __forceinline__ bool is_aligned(T *p) {
return ((uint64_t)p) % (ILP * sizeof(T)) == 0; return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
} }
...@@ -29,10 +28,9 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset, ...@@ -29,10 +28,9 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
((LT *)dst)[dst_offset] = ((LT *)src)[src_offset]; ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
} }
template <typename x_t> template <typename x_t> struct L2NormFunctor {
struct L2NormFunctor { __device__ __forceinline__ void
__device__ __forceinline__ void operator()( operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
float *output, float *output_per_tensor, bool per_tensor, float *output, float *output_per_tensor, bool per_tensor,
int max_chunks_per_tensor) { int max_chunks_per_tensor) {
// I'd like this kernel to propagate infs/nans. // I'd like this kernel to propagate infs/nans.
...@@ -50,8 +48,8 @@ struct L2NormFunctor { ...@@ -50,8 +48,8 @@ struct L2NormFunctor {
__shared__ float s_vals[512]; __shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be float
// sure... vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP]; x_t r_x[ILP];
for (int i = 0; i < ILP; i++) { for (int i = 0; i < ILP; i++) {
vals[i] = 0.f; vals[i] = 0.f;
...@@ -86,7 +84,8 @@ struct L2NormFunctor { ...@@ -86,7 +84,8 @@ struct L2NormFunctor {
} }
float val = 0.f; float val = 0.f;
for (int i = 0; i < ILP; i++) val += vals[i]; for (int i = 0; i < ILP; i++)
val += vals[i];
float final = reduce_block_into_lanes(s_vals, val); float final = reduce_block_into_lanes(s_vals, val);
...@@ -105,10 +104,9 @@ struct L2NormFunctor { ...@@ -105,10 +104,9 @@ struct L2NormFunctor {
// Probably better to template, but since we are not likely to support other // Probably better to template, but since we are not likely to support other
// norm // norm
template <typename x_t> template <typename x_t> struct MaxNormFunctor {
struct MaxNormFunctor { __device__ __forceinline__ void
__device__ __forceinline__ void operator()( operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
float *output, float *output_per_tensor, bool per_tensor, float *output, float *output_per_tensor, bool per_tensor,
int max_chunks_per_tensor) { int max_chunks_per_tensor) {
// I'd like this kernel to propagate infs/nans. // I'd like this kernel to propagate infs/nans.
...@@ -126,8 +124,8 @@ struct MaxNormFunctor { ...@@ -126,8 +124,8 @@ struct MaxNormFunctor {
__shared__ float s_vals[512]; __shared__ float s_vals[512];
float vals[ILP]; // = {0}; // this probably works too but I want to be float
// sure... vals[ILP]; // = {0}; // this probably works too but I want to be sure...
x_t r_x[ILP]; x_t r_x[ILP];
for (int i = 0; i < ILP; i++) { for (int i = 0; i < ILP; i++) {
vals[i] = 0.f; vals[i] = 0.f;
...@@ -162,7 +160,8 @@ struct MaxNormFunctor { ...@@ -162,7 +160,8 @@ struct MaxNormFunctor {
} }
float val = 0.f; float val = 0.f;
for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i])); for (int i = 0; i < ILP; i++)
val = fmaxf(fabsf(val), fabsf(vals[i]));
float final = reduce_block_into_lanes_max_op(s_vals, val); float final = reduce_block_into_lanes_max_op(s_vals, val);
...@@ -186,11 +185,13 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret, ...@@ -186,11 +185,13 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
if (blockIdx.x == 0) { if (blockIdx.x == 0) {
float val = 0; float val = 0;
if (threadIdx.x < 320) val = output[threadIdx.x]; if (threadIdx.x < 320)
val = output[threadIdx.x];
float final = reduce_block_into_lanes(vals, val); float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) *ret = sqrt(final); if (threadIdx.x == 0)
*ret = sqrt(final);
} }
if (per_tensor) { if (per_tensor) {
...@@ -203,7 +204,8 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret, ...@@ -203,7 +204,8 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
float final = reduce_block_into_lanes(vals, val); float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final); if (threadIdx.x == 0)
ret_per_tensor[blockIdx.x] = sqrt(final);
} }
} }
...@@ -215,14 +217,17 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret, ...@@ -215,14 +217,17 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
if (blockIdx.x == 0) { if (blockIdx.x == 0) {
float val = 0; float val = 0;
if (threadIdx.x < 320) val = output[threadIdx.x]; if (threadIdx.x < 320)
val = output[threadIdx.x];
if (norm_type == 0) { if (norm_type == 0) {
float final = reduce_block_into_lanes_max_op(vals, val); float final = reduce_block_into_lanes_max_op(vals, val);
if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final; if (threadIdx.x == 0)
*ret = alpha * (*ret) + beta * final;
} else { } else {
float final = reduce_block_into_lanes(vals, val); float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final); if (threadIdx.x == 0)
*ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
} }
} }
...@@ -255,8 +260,8 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret, ...@@ -255,8 +260,8 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
} }
} }
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda( std::tuple<at::Tensor, at::Tensor>
int chunk_size, at::Tensor noop_flag, multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
std::vector<std::vector<at::Tensor>> tensor_lists, std::vector<std::vector<at::Tensor>> tensor_lists,
at::optional<bool> per_tensor_python) { at::optional<bool> per_tensor_python) {
bool per_tensor = bool per_tensor =
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment