Unverified commit 58580b50, authored by ver217, committed by GitHub

Revert "[NFC] Hotfix/format (#984)" (#986)

This reverts commit 0772828f.
parent 0772828f
@@ -2,4 +2,3 @@ from .initialize import (initialize, launch, launch_from_openmpi,
                           launch_from_slurm, launch_from_torch, get_default_parser)
-
 __version__ = '0.0.1'
@@ -251,9 +251,9 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
     partitions = partition_uniform(len(layers), pipeline_parallel_size, num_chunks)
     module_list = []
     for start, end in partitions[pipeline_rank]:
-        module_list.append(
-            nn.Sequential(*[nn.Identity() for _ in range(start)],
-                          *[nn.Identity() for _ in range(len(layers) - end)]))
+        module_list.append(nn.Sequential(*[nn.Identity() for _ in range(start)],
+                                         *layers[start:end],
+                                         *[nn.Identity() for _ in range(len(layers) - end)]))
     if verbose:
         logger = get_dist_logger()
         logger.info(f'Total {len(layers)} layers', ranks=[0])
@@ -264,3 +264,4 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
             log_str += '\n'.join([str(layer) for layer in layers[start:end]]) + '\n'
         logger.info(log_str, ranks=[0])
-    return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
\ No newline at end of file
+    return nn.ModuleList(module_list) if len(module_list) > 1 else module_list[0]
@@ -20,14 +20,12 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE
 */
 #include "cpu_adam.h"
-#include <iostream>
 #include <math.h>
-#include <memory>
 #include <omp.h>
 #include <string.h>
 #include <torch/extension.h>
+#include <iostream>
+#include <memory>
 #include <type_traits>
 #include <unordered_map>
@@ -84,7 +82,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size)
+      copy_size = rounded_size - t;
     size_t offset = copy_size + t;
 #pragma omp parallel for
@@ -146,7 +145,8 @@ void Adam_Optimizer::Step_1(float *_params, float *grads, float *_exp_avg,
   if (_param_size > rounded_size) {
     for (size_t t = rounded_size; t < _param_size; t += TILE) {
       size_t copy_size = TILE;
-      if ((t + TILE) > _param_size) copy_size = _param_size - t;
+      if ((t + TILE) > _param_size)
+        copy_size = _param_size - t;
       size_t offset = copy_size + t;
 #pragma omp parallel for
@@ -235,7 +235,8 @@ void Adam_Optimizer::Step_4(float *_params, float *grads, float *_exp_avg,
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size)
+      copy_size = rounded_size - t;
     size_t offset = copy_size + t;
 #pragma omp parallel for
@@ -320,6 +321,7 @@ int create_adam_optimizer(int optimizer_id, float alpha = 1e-3,
   s_optimizers[optimizer_id] = opt;
+
   if (should_log) {
     std::string avx_type = "";
 #if defined(__AVX512__)
     avx_type = "AVX512";
@@ -384,7 +386,8 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
   for (size_t t = 0; t < rounded_size; t += TILE) {
     size_t copy_size = TILE;
-    if ((t + TILE) > rounded_size) copy_size = rounded_size - t;
+    if ((t + TILE) > rounded_size)
+      copy_size = rounded_size - t;
     size_t offset = copy_size + t;
 #pragma omp parallel for
@@ -460,29 +463,43 @@ void Adam_Optimizer::Step_8(float *_params, float *grads, float *_exp_avg,
                 grad_half_precision, loss_scale);
 }
-int adam_step(int optimizer_id, size_t step, float lr, float beta1, float beta2,
-              float epsilon, float weight_decay, bool bias_correction,
-              torch::Tensor &params, torch::Tensor &grads,
-              torch::Tensor &exp_avg, torch::Tensor &exp_avg_sq,
-              float loss_scale) {
-  auto params_c = params.contiguous();
-  auto grads_c = grads.contiguous();
-  auto exp_avg_c = exp_avg.contiguous();
-  auto exp_avg_sq_c = exp_avg_sq.contiguous();
-
-  float *params_ptr = (float *)params_c.data_ptr();
-  float *grads_ptr = (float *)grads_c.data_ptr();
-  float *exp_avg_ptr = (float *)exp_avg_c.data_ptr();
-  float *exp_avg_sq_ptr = (float *)exp_avg_sq_c.data_ptr();
-  std::shared_ptr<Adam_Optimizer> opt =
-      std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
-  opt->IncrementStep(step, beta1, beta2);
-  opt->update_state(lr, epsilon, weight_decay, bias_correction);
-  opt->Step_8(params_ptr, grads_ptr, exp_avg_ptr, exp_avg_sq_ptr,
-              params_c.numel(), (params.options().dtype() == at::kHalf),
-              (grads.options().dtype() == at::kHalf), loss_scale);
-
-  return 0;
+int adam_step(int optimizer_id,
+              size_t step,
+              float lr,
+              float beta1,
+              float beta2,
+              float epsilon,
+              float weight_decay,
+              bool bias_correction,
+              torch::Tensor& params,
+              torch::Tensor& grads,
+              torch::Tensor& exp_avg,
+              torch::Tensor& exp_avg_sq,
+              float loss_scale)
+{
+    auto params_c = params.contiguous();
+    auto grads_c = grads.contiguous();
+    auto exp_avg_c = exp_avg.contiguous();
+    auto exp_avg_sq_c = exp_avg_sq.contiguous();
+
+    float* params_ptr = (float*)params_c.data_ptr();
+    float* grads_ptr = (float*)grads_c.data_ptr();
+    float* exp_avg_ptr = (float*)exp_avg_c.data_ptr();
+    float* exp_avg_sq_ptr = (float*)exp_avg_sq_c.data_ptr();
+    std::shared_ptr<Adam_Optimizer> opt =
+        std::static_pointer_cast<Adam_Optimizer>(s_optimizers[optimizer_id]);
+    opt->IncrementStep(step, beta1, beta2);
+    opt->update_state(lr, epsilon, weight_decay, bias_correction);
+    opt->Step_8(params_ptr,
+                grads_ptr,
+                exp_avg_ptr,
+                exp_avg_sq_ptr,
+                params_c.numel(),
+                (params.options().dtype() == at::kHalf),
+                (grads.options().dtype() == at::kHalf),
+                loss_scale);
+
+    return 0;
 }
 int destroy_adam_optimizer(int optimizer_id) {
......
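The Step_1/Step_4/Step_8 hunks above all touch the same tiling idiom: walk the parameter buffer in TILE-sized chunks and clamp the size of the final, partial chunk. A minimal standalone sketch of that idiom; TILE's value and the worker function are placeholders chosen here for illustration, not the extension's real API:

#include <cstddef>

constexpr size_t TILE = 1048576; // assumed chunk size; the real constant lives in cpu_adam.h

// Stand-in for the vectorized Adam update applied to one tile.
static void process_tile(float *params, size_t offset, size_t count) {
  (void)params;
  (void)offset;
  (void)count;
}

void step_over_buffer(float *params, size_t param_size) {
  for (size_t t = 0; t < param_size; t += TILE) {
    size_t copy_size = TILE;
    if ((t + TILE) > param_size)
      copy_size = param_size - t; // clamp the last, partial tile
    process_tile(params, t, copy_size);
  }
}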
@@ -48,10 +48,10 @@ SOFTWARE
 #define SIMD_FMA(x, y, c) _mm512_fmadd_ps(x, y, c)
 #define SIMD_SQRT(x) _mm512_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm512_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) \
   _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(x)))
 #define SIMD_STORE_HALF(x, d) \
   _mm256_store_ps( \
       x, _mm256_castsi256_ps(_mm512_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
 #elif defined(__AVX256__) or defined(__AVX2__)
@@ -66,8 +66,8 @@ SOFTWARE
 #define SIMD_SQRT(x) _mm256_sqrt_ps(x)
 #define SIMD_DIV(x, y) _mm256_div_ps(x, y)
 #define SIMD_LOAD_HALF(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
 #define SIMD_STORE_HALF(x, d) \
   _mm_store_ps( \
       x, _mm_castsi128_ps(_mm256_cvtps_ph(d, _MM_FROUND_TO_NEAREST_INT)))
 #endif
@@ -83,25 +83,19 @@ union AVX_Data {
 #endif
 #define STEP(SPAN) \
   void Step_##SPAN(float *_params, float *grads, float *_exp_avg, \
                    float *_exp_avg_sq, size_t _param_size, \
                    bool param_half_precision = false, \
                    bool grad_half_precision = false, float loss_scale = -1);
 class Adam_Optimizer {
 public:
   Adam_Optimizer(float alpha = 1e-3, float betta1 = 0.9, float betta2 = 0.999,
                  float eps = 1e-8, float weight_decay = 0,
                  bool adamw_mode = true)
-      : _alpha(alpha),
-        _betta1(betta1),
-        _betta2(betta2),
-        _eps(eps),
-        _weight_decay(weight_decay),
-        _betta1_t(1.0),
-        _betta2_t(1.0),
-        _step(0),
+      : _alpha(alpha), _betta1(betta1), _betta2(betta2), _eps(eps),
+        _weight_decay(weight_decay), _betta1_t(1.0), _betta2_t(1.0), _step(0),
         _adamw_mode(adamw_mode) {}
   ~Adam_Optimizer() {}
@@ -141,7 +135,7 @@ class Adam_Optimizer {
     }
   }
 private:
   float _alpha;
   float _betta1;
   float _betta2;
......
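For reference, the STEP(SPAN) macro kept intact in the hunk above stamps out one Step_ declaration per SIMD span. Reconstructed from the macro body shown there (a sketch of the expansion, not text copied from the header), STEP(1) expands to:

void Step_1(float *_params, float *grads, float *_exp_avg,
            float *_exp_avg_sq, size_t _param_size,
            bool param_half_precision = false,
            bool grad_half_precision = false, float loss_scale = -1);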
@@ -16,7 +16,7 @@ __global__ void ls_cross_entropy_fw_kernel(
   const int left_idx = block_start + threadIdx.x;
   const int right_idx = (blockIdx.x + 1) * vocab_size;
   float max_input[1] = {REDUCE_FLOAT_INF_NEG};
-  float sum_logits[2] = {0.f, 0.f};  // logit and logit exp
+  float sum_logits[2] = {0.f, 0.f}; // logit and logit exp
   int target_tid = targets[blockIdx.x];
   if (target_tid == padding_idx) {
......
-#include <cooperative_groups.h>
 #include <chrono>
 #include <ctime>
 #include "kernels.h"
+#include <cooperative_groups.h>
 namespace cg = cooperative_groups;
 curandStatePhilox4_32_10_t *curandstate;
@@ -165,7 +165,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -201,7 +202,8 @@ __global__ void ls_dropout_kernel(const int total_count, const float ratio,
   int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -259,7 +261,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;
   uint8_t m[4];
@@ -286,7 +289,8 @@ __global__ void ls_dropout_bwd_kernel(const int total_count, const float ratio,
   int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;
   float4 *out4 = reinterpret_cast<float4 *>(out);
   const float4 *vals_float4 = reinterpret_cast<const float4 *>(in);
@@ -376,7 +380,8 @@ __global__ void ls_dropout_res_bias_kernel(
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -419,7 +424,8 @@ __global__ void ls_dropout_res_bias_kernel(
   int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -559,9 +565,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
   }
   __syncthreads();
-  for (int i = 1; i < 32; i <<= 1) sum += g.shfl_down(sum, i);
+  for (int i = 1; i < 32; i <<= 1)
+    sum += g.shfl_down(sum, i);
-  if (y == 0) tile[0][x] = sum;
+  if (y == 0)
+    tile[0][x] = sum;
   __syncthreads();
   if (threadIdx.x < 8) {
@@ -613,9 +621,11 @@ __global__ void ls_dropout_bias_bwd_kernel(
   }
   __syncthreads();
-  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1)
+    sum += g.shfl_down(sum, i);
-  if (y == 0) tile[0][x] = sum;
+  if (y == 0)
+    tile[0][x] = sum;
   __syncthreads();
   if (threadIdx.x < 8) {
@@ -679,7 +689,8 @@ __global__ void ls_dropout_act_bias_kernel(
   const float scale = 1.f / (1.f - ratio);
   int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 4 >= total_count) return;
+  if (i * 4 >= total_count)
+    return;
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -724,7 +735,8 @@ __global__ void ls_dropout_act_bias_kernel(
   int i = blockIdx.x * blockDim.x + threadIdx.x;
-  if (i * 8 >= total_count) return;
+  if (i * 8 >= total_count)
+    return;
   curandStatePhilox4_32_10_t state;
   curand_init(seed, i, 0, &state);
@@ -885,9 +897,11 @@ __global__ void ls_dropout_act_bias_bwd_kernel(
   float sum = tile[threadIdx.y][threadIdx.x];
   __syncthreads();
-  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_down(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1)
+    sum += g.shfl_down(sum, i);
-  if (threadIdx.x == 0) tile[0][threadIdx.y] = sum;
+  if (threadIdx.x == 0)
+    tile[0][threadIdx.y] = sum;
   __syncthreads();
   if (threadIdx.y == 0) {
......
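Every guard split across two lines in the dropout kernels encodes the same launch contract: each thread consumes 4 (or 8) elements at once through vectorized loads, so a thread whose first element index falls past total_count must exit before touching memory. A compilable sketch under an assumed kernel name; the real kernels also generate the Philox dropout mask:

__global__ void vec4_guard_kernel(const int total_count, const float4 *in,
                                  float4 *out) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i * 4 >= total_count) // thread i owns elements [4 * i, 4 * i + 3]
    return;
  out[i] = in[i]; // the real kernels apply the mask and 1/(1-ratio) scale here
}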
-#include <cooperative_groups.h>
 #include "kernels.h"
+#include <cooperative_groups.h>
 namespace cg = cooperative_groups;
 /**
......
@@ -13,23 +13,22 @@ const float REDUCE_FLOAT_INF_NEG = -100000000.f;
 const float REDUCE_FLOAT_INF_POS = 100000000.f;
 const unsigned int WARP_REDUCE_SIZE = 32;
-template <typename T>
-__forceinline__ __device__ T warpReduceSum(T val) {
+template <typename T> __forceinline__ __device__ T warpReduceSum(T val) {
   for (int mask = (WARP_REDUCE_SIZE >> 1); mask > 0; mask >>= 1)
     val += __shfl_xor_sync(WARP_REDUCE_MASK, val, mask, WARP_REDUCE_SIZE);
   return val;
 }
 /* Calculate the sum of all elements in a block */
-template <typename T>
-__forceinline__ __device__ T blockReduceSum(T val) {
+template <typename T> __forceinline__ __device__ T blockReduceSum(T val) {
   static __shared__ T shared[32];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
   val = warpReduceSum<T>(val);
-  if (lane == 0) shared[wid] = val;
+  if (lane == 0)
+    shared[wid] = val;
   __syncthreads();
   val = (threadIdx.x < (blockDim.x >> 5)) ? shared[lane] : (T)0.0f;
@@ -57,10 +56,10 @@ __inline__ __device__ void warpReduce<ReduceType::kMax, 1>(float *pval) {
 template <>
 __inline__ __device__ void warpReduce<ReduceType::kMax, 2>(float *pval) {
   float val0_tmp, val1_tmp;
 #define WarpReduceMaxOneStep(a, b) \
   val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval), a, b); \
   val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
   *(pval) = max(val0_tmp, *(pval)); \
   *(pval + 1) = max(val1_tmp, *(pval + 1));
   WarpReduceMaxOneStep(16, 32);
@@ -89,10 +88,10 @@ __inline__ __device__ void warpReduce<ReduceType::kSum, 1>(float *pval) {
 template <>
 __inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
   float val0_tmp, val1_tmp;
 #define WarpReduceSumOneStep(a, b) \
   val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
   val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
   *(pval + 0) += val0_tmp; \
   *(pval + 1) += val1_tmp
   WarpReduceSumOneStep(16, 32);
@@ -107,14 +106,14 @@ __inline__ __device__ void warpReduce<ReduceType::kSum, 2>(float *pval) {
 template <>
 __inline__ __device__ void warpReduce<ReduceType::kSum, 4>(float *pval) {
   float val0_tmp, val1_tmp, val2_tmp, val3_tmp;
 #define WarpReduceSumOneStep(a, b) \
   val0_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 0), a, b); \
   val1_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 1), a, b); \
   val2_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 2), a, b); \
   val3_tmp = __shfl_xor_sync(WARP_REDUCE_MASK, *(pval + 3), a, b); \
   *(pval + 0) += val0_tmp; \
   *(pval + 1) += val1_tmp; \
   *(pval + 2) += val2_tmp; \
   *(pval + 3) += val3_tmp
   WarpReduceSumOneStep(16, 32);
......
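warpReduceSum above is a butterfly (XOR) reduction: each step exchanges partial sums between lanes whose IDs differ in one bit, so after log2(32) steps every lane holds the full warp sum. The same shape written out with literal constants, as a sketch:

__device__ float warp_sum_sketch(float val) {
  // 0xffffffff: all 32 lanes participate; the mask walks 16, 8, 4, 2, 1.
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
  return val; // identical on every lane afterwards
}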
@@ -9,7 +9,7 @@
 #include "cuda_util.h"
 class Context {
 public:
   Context() : _stream(nullptr) {
     CHECK_GPU_ERROR(cublasCreate(&_cublasHandle));
   }
@@ -30,7 +30,7 @@ class Context {
   cublasHandle_t get_cublashandle() { return _cublasHandle; }
 private:
   cudaStream_t _stream;
   cublasHandle_t _cublasHandle;
 };
@@ -8,9 +8,8 @@
 #include "cuda_util.h"
-template <typename T>
-class CrossEntropyLayer {
- public:
+template <typename T> class CrossEntropyLayer {
+public:
   CrossEntropyLayer(float epsilon, int padding_idx, int max_batch_tokens);
   virtual ~CrossEntropyLayer();
@@ -23,7 +22,7 @@ class CrossEntropyLayer {
   void set_cur_batch_shape(int batch_size, int seq_len, int vocab_size);
 private:
   void allocate_mem_buffer() {
     // allocate local gpu memory
     _loss_buffer = cuda_malloc<float>(_max_batch_tokens * 2);
......
@@ -20,8 +20,7 @@ void check_gpu_error(T result, char const *const func, const char *const file,
 template <typename T>
 void print_vec(const T *outv, std::string outn, int num_output_ele);
-template <typename T>
-T *cuda_malloc(size_t ele_num);
+template <typename T> T *cuda_malloc(size_t ele_num);
 void cuda_free(void *pdata);
@@ -29,6 +28,6 @@ template <typename T>
 void check_nan_inf(const T *data_ptr, int dsize, bool check_nan_inf,
                    std::string file, int line, cudaStream_t stream);
 #define CHECK_NAN_INF(ptr, size, stream) \
   check_nan_inf((ptr), (size), true, __FILE__, __LINE__, (stream)); \
   check_nan_inf((ptr), (size), false, __FILE__, __LINE__, (stream))
@@ -3,14 +3,12 @@
 #include <cuda.h>
 #include <cuda_fp16.h>
 #include <stdio.h>
 #include <string>
 #include "kernels.h"
-template <typename T>
-class Dropout {
- public:
+template <typename T> class Dropout {
+public:
   struct Config {
     float ratio;
     bool training;
@@ -90,7 +88,7 @@ class Dropout {
   void SetTrainingMode(bool training) { _config.training = training; }
 private:
   uint8_t *_mask;
   Config _config;
 };
@@ -13,16 +13,14 @@
 #include "cublas_wrappers.h"
 #include "kernels.h"
-template <typename T>
-class FeedForward {
- public:
+template <typename T> class FeedForward {
+public:
   struct Config {
     int outputSize;
     int inputSize;
     std::array<int, 3> gemm_algos;
     Config(int outputs, int inputs)
-        : outputSize(outputs),
-          inputSize(inputs),
+        : outputSize(outputs), inputSize(inputs),
           gemm_algos(std::array<int, 3>({99, 99, 99})) {}
   };
@@ -63,6 +61,6 @@ class FeedForward {
     config_.inputSize = inputSize;
   }
 private:
   Config config_;
 };
@@ -10,9 +10,8 @@
 using namespace std;
-template <typename T>
-class Softmax {
- public:
+template <typename T> class Softmax {
+public:
   struct Config {
     size_t nhead;
     Config(size_t nhead) : nhead(nhead) {}
@@ -37,6 +36,6 @@ class Softmax {
   void reset_size(size_t nhead) { config_.nhead = nhead; }
 private:
   Config config_;
 };
#include "block_reduce.h" #include "block_reduce.h"
#include "kernels.h" #include "kernels.h"
#include <cooperative_groups.h> #include <cooperative_groups.h>
namespace cg = cooperative_groups; namespace cg = cooperative_groups;
......
-#include <cooperative_groups.h>
 #include <math.h>
 #include <cub/block/block_load.cuh>
@@ -7,6 +6,8 @@
 #include "block_reduce.h"
 #include "kernels.h"
+#include <cooperative_groups.h>
+
 namespace cg = cooperative_groups;
 const float EPSILON = 1e-8f;
@@ -119,7 +120,7 @@ __global__ void ker_attn_softmax(T *inp, const T *attn_mask, int from_len,
     BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                to_len);
   }
-  }  // blockIdx.x
+  } // blockIdx.x
 }
 template <typename T, int block_dim, int ele_per_thread>
@@ -197,7 +198,7 @@ __global__ void ker_attn_softmax_lt32(T *inp, const T *attn_mask, int from_len,
     BlockStore(ts_store).Store(inp + (token_id + i) * to_len, inp_val[i],
                                to_len);
   }
-  }  // blockIdx.x
+  } // blockIdx.x
 }
 /*
@@ -303,7 +304,8 @@ __global__ void ker_attn_softmax_bw(T *grad, const T *inp, int softmax_length) {
   cg::thread_block b = cg::this_thread_block();
   cg::thread_block_tile<WARP_SIZE> g = cg::tiled_partition<WARP_SIZE>(b);
-  for (int i = 1; i < WARP_SIZE; i <<= 1) sum += g.shfl_xor(sum, i);
+  for (int i = 1; i < WARP_SIZE; i <<= 1)
+    sum += g.shfl_xor(sum, i);
 #pragma unroll
   for (int i = 0; i < ITERATIONS; ++i) {
......
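ker_attn_softmax_bw reaches the same warp-wide sum through cooperative groups: the tile handle g scopes the shuffle to a 32-thread partition instead of passing an explicit lane mask. A minimal sketch of that pattern:

#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__device__ float tile_sum_sketch(float sum) {
  cg::thread_block b = cg::this_thread_block();
  cg::thread_block_tile<32> g = cg::tiled_partition<32>(b); // 32 = WARP_SIZE
  for (int i = 1; i < 32; i <<= 1)
    sum += g.shfl_xor(sum, i);
  return sum;
}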
@@ -2,13 +2,11 @@
  * https://github.com/NVIDIA/apex
  * with minor changes. */
-#include <torch/extension.h>
+#include "compat.h"
 #include <cassert>
+#include <torch/extension.h>
 #include <vector>
-#include "compat.h"
 namespace {
 void compute_n1_n2(at::Tensor input, at::IntArrayRef normalized_shape, int &n1,
@@ -67,7 +65,7 @@ void check_args(at::Tensor input, at::IntArrayRef normalized_shape,
   check_args(input, normalized_shape, n1, n2);
   check_args(normalized_shape, gamma, beta);
 }
-}  // namespace
+} // namespace
 void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
                      at::Tensor *input, int n1, int n2,
@@ -75,16 +73,17 @@ void cuda_layer_norm(at::Tensor *output, at::Tensor *mean, at::Tensor *invvar,
                      at::Tensor *beta, double epsilon);
 #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) \
   TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) \
   CHECK_CUDA(x); \
   CHECK_CONTIGUOUS(x)
 std::vector<at::Tensor> layer_norm_affine(at::Tensor input,
                                           at::IntArrayRef normalized_shape,
                                           at::Tensor gamma, at::Tensor beta,
                                           double epsilon) {
+
   CHECK_INPUT(input);
   CHECK_INPUT(gamma);
   CHECK_INPUT(beta);
@@ -110,10 +109,11 @@ void cuda_layer_norm_gradient(at::Tensor *dout, at::Tensor *mean,
                               double epsilon, at::Tensor *grad_input,
                               at::Tensor *grad_gamma, at::Tensor *grad_beta);
-std::vector<at::Tensor> layer_norm_gradient_affine(
-    at::Tensor dout, at::Tensor mean, at::Tensor invvar, at::Tensor input,
-    at::IntArrayRef normalized_shape, at::Tensor gamma, at::Tensor beta,
-    double epsilon) {
+std::vector<at::Tensor>
+layer_norm_gradient_affine(at::Tensor dout, at::Tensor mean, at::Tensor invvar,
+                           at::Tensor input, at::IntArrayRef normalized_shape,
+                           at::Tensor gamma, at::Tensor beta, double epsilon) {
+
   CHECK_INPUT(dout);
   CHECK_INPUT(mean);
   CHECK_INPUT(invvar);
......
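The CHECK_* macros whose continuation lines are realigned above are the standard guard at a PyTorch extension boundary; TORCH_CHECK is the real PyTorch API, while the entry point below is only an illustrative stand-in for functions like layer_norm_affine:

#include <torch/extension.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)

void example_entry(at::Tensor input) {
  CHECK_INPUT(input); // raises a catchable Python-side error on bad inputs
}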
@@ -15,24 +15,25 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
                                        torch::Tensor logits, torch::Tensor mask,
                                        torch::Tensor dest_idx);
-std::vector<torch::Tensor> moe_combine_cuda_backward(
-    int s, int e, int c, int h, torch::Tensor tokens_grad,
-    torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
-    torch::Tensor dest_idx);
+std::vector<torch::Tensor>
+moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
+                          torch::Tensor expert_tokens, torch::Tensor logits,
+                          torch::Tensor mask, torch::Tensor dest_idx);
 torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask);
 #define CHECK_CUDA(x) \
   TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) \
   TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
 #define CHECK_INPUT(x) \
   CHECK_CUDA(x); \
   CHECK_CONTIGUOUS(x)
 torch::Tensor moe_dispatch_forward(int s, int ec, int h,
                                    torch::Tensor batch_tokens,
                                    torch::Tensor mask, torch::Tensor dest_idx) {
+
   CHECK_INPUT(batch_tokens);
   CHECK_CUDA(mask);
   CHECK_CUDA(dest_idx);
@@ -44,6 +45,7 @@ torch::Tensor moe_dispatch_backward(int s, int ec, int h,
                                     torch::Tensor expert_grad,
                                     torch::Tensor mask,
                                     torch::Tensor dest_idx) {
+
   CHECK_INPUT(expert_grad);
   CHECK_CUDA(mask);
   CHECK_CUDA(dest_idx);
@@ -55,6 +57,7 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
                                   torch::Tensor expert_tokens,
                                   torch::Tensor logits, torch::Tensor mask,
                                   torch::Tensor dest_idx) {
+
   CHECK_INPUT(expert_tokens);
   CHECK_INPUT(logits);
   CHECK_CUDA(mask);
@@ -64,12 +67,11 @@ torch::Tensor moe_combine_forward(int s, int e, int c, int h,
                                  dest_idx);
 }
-std::vector<torch::Tensor> moe_combine_backward(int s, int e, int c, int h,
-                                                torch::Tensor tokens_grad,
-                                                torch::Tensor expert_tokens,
-                                                torch::Tensor logits,
-                                                torch::Tensor mask,
-                                                torch::Tensor dest_idx) {
+std::vector<torch::Tensor>
+moe_combine_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
+                     torch::Tensor expert_tokens, torch::Tensor logits,
+                     torch::Tensor mask, torch::Tensor dest_idx) {
+
   CHECK_INPUT(tokens_grad);
   CHECK_INPUT(logits);
   CHECK_CUDA(mask);
......
#include "block_reduce.h"
#include <cub/cub.cuh>
#include <cuda.h> #include <cuda.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <cub/cub.cuh>
#include "block_reduce.h"
template <typename T, int block_size, int pack_size> template <typename T, int block_size, int pack_size>
__device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) { __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
assert(cols % pack_size == 0); assert(cols % pack_size == 0);
const int bpack_size = block_size * pack_size; const int bpack_size = block_size * pack_size;
@@ -29,6 +28,7 @@ __device__ void moe_dpch_one_fwd(T *src_row, T *dst_row, const int cols) {
 template <typename T, int block_size, int pack_size>
 __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
+
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;
@@ -51,6 +51,7 @@ __device__ void moe_dpch_one_bwd(T *src_row, T *dst_row, const int cols) {
 template <typename T, int block_size, int pack_size>
 __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
                                  const int cols) {
+
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;
@@ -74,6 +75,7 @@ __device__ void moe_dpch_two_fwd(T *src_row, T *dst_row1, T *dst_row2,
 template <typename T, int block_size, int pack_size>
 __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
                                  const int cols) {
+
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;
@@ -103,6 +105,7 @@ __device__ void moe_dpch_two_bwd(T *src_row, T *dst_row1, T *dst_row2,
 template <typename T, int block_size, int pack_size>
 __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
                                const int cols) {
+
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;
@@ -131,6 +134,7 @@ __device__ void moe_cb_one_fwd(T *src_row, T *dst_row, const T weight,
 template <typename T, int block_size, int pack_size>
 __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
                                T *weight_grad, const T weight, const int cols) {
+
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;
@@ -160,13 +164,15 @@ __device__ void moe_cb_one_bwd(T *src_row, T *dst_row, T *tks_row,
   blockReduce<ReduceType::kSum, 1>(&thread_sum);
-  if (threadIdx.x == 0) *weight_grad = static_cast<T>(thread_sum);
+  if (threadIdx.x == 0)
+    *weight_grad = static_cast<T>(thread_sum);
 }
 template <typename T, int block_size, int pack_size>
 __device__ void moe_cb_two_fwd(T *src_row1, T *src_row2, T *dst_row,
                                const T weight1, const T weight2,
                                const int cols) {
+
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;
@@ -198,6 +204,7 @@ __device__ void moe_cb_two_bwd(T *src_row1, T *src_row2, T *dst_row,
                                T *tks_row1, T *tks_row2, T *weight_grad1,
                                T *weight_grad2, const T weight1,
                                const T weight2, const int cols) {
+
   assert(cols % pack_size == 0);
   const int bpack_size = block_size * pack_size;
@@ -244,6 +251,7 @@ template <typename T, int block_size, int pack_size>
 __device__ void moe_dpch_fwd_selector(T *src_row, T *dst_row1, T *dst_row2,
                                       const int cols, const int indicator1,
                                       const int indicator2) {
+
   if (indicator1 != 0 && indicator2 != 0)
     moe_dpch_two_fwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                                cols);
@@ -259,6 +267,7 @@ template <typename T, int block_size, int pack_size>
 __device__ void moe_dpch_bwd_selector(T *src_row, T *dst_row1, T *dst_row2,
                                       const int cols, const int indicator1,
                                       const int indicator2) {
+
   if (indicator1 != 0 && indicator2 != 0)
     moe_dpch_two_bwd<T, block_size, pack_size>(src_row, dst_row1, dst_row2,
                                                cols);
@@ -274,6 +283,7 @@ template <typename T, int block_size, int pack_size>
 __global__ void moe_dpch_fwd_kernel(T *batch_tokens, T *expert_input,
                                     int *mask1, int *mask2, int *dest1,
                                     int *dest2, const int h) {
+
   int row = blockIdx.x;
   int indicator2 = mask2 == nullptr ? 0 : mask2[row];
   moe_dpch_fwd_selector<T, block_size, pack_size>(
@@ -285,6 +295,7 @@ template <typename T, int block_size, int pack_size>
 __global__ void moe_dpch_bwd_kernel(T *tokens_grad, T *expert_grad, int *mask1,
                                     int *mask2, int *dest1, int *dest2,
                                     const int h) {
+
   int row = blockIdx.x;
   int indicator2 = mask2 == nullptr ? 0 : mask2[row];
   moe_dpch_bwd_selector<T, block_size, pack_size>(
@@ -299,6 +310,7 @@ __device__ void moe_cb_fwd_selector(T *src_row1, T *src_row2, T *dst_row,
                                     const int cols, const T weight1,
                                     const T weight2, const int indicator1,
                                     const int indicator2) {
+
   if (indicator1 != 0 && indicator2 != 0)
     moe_cb_two_fwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
                                              weight1, weight2, cols);
@@ -316,6 +328,7 @@ __device__ void moe_cb_bwd_selector(T *src_row1, T *src_row2, T *dst_row,
                                     T *wt_grad1, T *wt_grad2, const T weight1,
                                     const T weight2, const int indicator1,
                                     const int indicator2) {
+
   if (indicator1 != 0 && indicator2 != 0)
     moe_cb_two_bwd<T, block_size, pack_size>(src_row1, src_row2, dst_row,
                                              tks_row1, tks_row2, wt_grad1,
@@ -335,6 +348,7 @@ __global__ void moe_cb_fwd_kernel(T *expert_tokens, T *combine_tokens,
                                   T *logits, int *mask1, int *mask2, int *dest1,
                                   int *dest2, const int e, const int c,
                                   const int h) {
+
   int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
   int indicator2 = mask2 == nullptr ? 0 : mask2[row];
   T *row_log = logits + (row * e);
@@ -349,6 +363,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
                                   T *logits, T *logits_grad, int *mask1,
                                   int *mask2, int *dest1, int *dest2,
                                   const int e, const int c, const int h) {
+
   int row = blockIdx.x, eid1 = dest1[row] / c, eid2 = dest2[row] / c;
   int indicator2 = mask2 == nullptr ? 0 : mask2[row];
   T *row_log = logits + (row * e), *row_grad = logits_grad + (row * e);
@@ -364,6 +379,7 @@ __global__ void moe_cb_bwd_kernel(T *tokens_grad, T *expert_grad, T *tks,
 template <int block_size, int pack_size>
 __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
                               const int e) {
+
   assert(s % pack_size == 0);
   constexpr int bpack_size = block_size * pack_size;
   int tid = threadIdx.x, bid = blockIdx.x, tps = tid * pack_size, last_sum = -1;
@@ -410,7 +426,8 @@ __global__ void cumsum_kernel(int *inputs, int *outputs, const int s,
     }
     __syncthreads();
-    if (tid == 0) temp[0] = temp[block_size];
+    if (tid == 0)
+      temp[0] = temp[block_size];
     __syncthreads();
     if (idx + tps < s) {
@@ -436,6 +453,7 @@ template <typename T>
 void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
                          int *mask2, int *dest1, int *dest2, const int s,
                          const int h) {
+
   if (h < 256)
     moe_dpch_fwd_kernel<T, 32, 4>
         <<<s, 32>>>(batch_tokens, expert_input, mask1, mask2, dest1, dest2, h);
@@ -456,6 +474,7 @@ void moe_dpch_fwd_launch(T *batch_tokens, T *expert_input, int *mask1,
 template <typename T>
 void moe_dpch_bwd_launch(T *tokens_grad, T *expert_grad, int *mask1, int *mask2,
                          int *dest1, int *dest2, const int s, const int h) {
+
   if (h < 256)
     moe_dpch_bwd_kernel<T, 32, 4>
         <<<s, 32>>>(tokens_grad, expert_grad, mask1, mask2, dest1, dest2, h);
@@ -477,6 +496,7 @@ template <typename T>
 void moe_cb_fwd_launch(T *expert_tokens, T *combine_tokens, T *logits,
                        int *mask1, int *mask2, int *dest1, int *dest2,
                        const int s, const int e, const int c, const int h) {
+
   if (h < 256)
     moe_cb_fwd_kernel<T, 32, 4><<<s, 32>>>(expert_tokens, combine_tokens,
                                            logits, mask1, mask2, dest1, dest2,
@@ -504,11 +524,12 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
                        T *logits_grad, int *mask1, int *mask2, int *dest1,
                        int *dest2, const int s, const int e, const int c,
                        const int h) {
+
   if (h < 256)
     moe_cb_bwd_kernel<T, 32, 4><<<s, 32>>>(tokens_grad, expert_grad, tks,
                                            logits, logits_grad, mask1, mask2,
                                            dest1, dest2, e, c, h);
-  else  // if (h < 512)
+  else // if (h < 512)
     moe_cb_bwd_kernel<T, 64, 4><<<s, 64>>>(tokens_grad, expert_grad, tks,
                                            logits, logits_grad, mask1, mask2,
                                            dest1, dest2, e, c, h);
@@ -523,6 +544,7 @@ void moe_cb_bwd_launch(T *tokens_grad, T *expert_grad, T *tks, T *logits,
 }
 void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
+
   if (s <= 256)
     cumsum_kernel<256, 1><<<e, 256>>>(inputs, outputs, s, e);
   else if (s <= 512)
@@ -537,26 +559,27 @@ void cumsum_launch(int *inputs, int *outputs, const int s, const int e) {
 // API FUNCTIONS --------------------------------
 #define DISPATCH_FLOAT_AND_HALF(TYPE, NAME, ...) \
   switch (TYPE) { \
     case at::ScalarType::Float: { \
       using scalar_t = float; \
       __VA_ARGS__; \
       break; \
     } \
     case at::ScalarType::Half: { \
       using scalar_t = at::Half; \
       __VA_ARGS__; \
       break; \
     } \
     default: \
       AT_ERROR(#NAME, " not implemented yet for specific data type."); \
   }
 torch::Tensor moe_dispatch_cuda_forward(int s, int ec, int h,
                                         torch::Tensor batch_tokens,
                                         torch::Tensor mask,
                                         torch::Tensor dest_idx) {
+
   assert(h % 16 == 0);
   auto res = torch::zeros(
       {ec, h},
@@ -578,6 +601,7 @@ torch::Tensor moe_dispatch_cuda_backward(int s, int ec, int h,
                                          torch::Tensor expert_grad,
                                          torch::Tensor mask,
                                          torch::Tensor dest_idx) {
+
   assert(h % 16 == 0);
   auto res = torch::zeros(
       {s, h}, torch::dtype(expert_grad.dtype()).device(expert_grad.device()));
@@ -598,6 +622,7 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
                                        torch::Tensor expert_tokens,
                                        torch::Tensor logits, torch::Tensor mask,
                                        torch::Tensor dest_idx) {
+
   assert(h % 16 == 0);
   assert(expert_tokens.dtype() == logits.dtype());
@@ -618,10 +643,11 @@ torch::Tensor moe_combine_cuda_forward(int s, int e, int c, int h,
   return res;
 }
-std::vector<torch::Tensor> moe_combine_cuda_backward(
-    int s, int e, int c, int h, torch::Tensor tokens_grad,
-    torch::Tensor expert_tokens, torch::Tensor logits, torch::Tensor mask,
-    torch::Tensor dest_idx) {
+std::vector<torch::Tensor>
+moe_combine_cuda_backward(int s, int e, int c, int h, torch::Tensor tokens_grad,
+                          torch::Tensor expert_tokens, torch::Tensor logits,
+                          torch::Tensor mask, torch::Tensor dest_idx) {
+
   assert(h % 16 == 0);
   assert(tokens_grad.dtype() == expert_tokens.dtype());
   assert(expert_tokens.dtype() == logits.dtype());
@@ -647,6 +673,7 @@ std::vector<torch::Tensor> moe_combine_cuda_backward(
 }
 torch::Tensor cumsum_sub_one_in_dim0(torch::Tensor mask) {
+
   assert(mask.dim() == 2);
   assert(mask.dtype() == torch::kInt32);
......
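The moe_*_launch helpers above share one dispatch rule: choose block_size and pack_size so that block_size * pack_size covers a row of width h, then launch one block per row (s blocks). A self-contained sketch of that selection with a trivial stand-in kernel (names here are illustrative, not the library's API):

template <typename T, int block_size, int pack_size>
__global__ void row_kernel_sketch(T *rows, const int h) {
  // Stand-in body; the real kernels move pack_size elements per thread.
  int col = threadIdx.x * pack_size;
  if (col < h)
    rows[blockIdx.x * h + col] = rows[blockIdx.x * h + col];
}

template <typename T>
void launch_rows_sketch(T *rows, const int s, const int h) {
  if (h < 256)
    row_kernel_sketch<T, 32, 4><<<s, 32>>>(rows, h); // 32 threads x 4-packs
  else
    row_kernel_sketch<T, 64, 4><<<s, 64>>>(rows, h); // wider rows, same rule
}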
@@ -16,8 +16,7 @@
 #define BLOCK_SIZE 512
 #define ILP 4
-template <typename T>
-__device__ __forceinline__ bool is_aligned(T *p) {
+template <typename T> __device__ __forceinline__ bool is_aligned(T *p) {
   return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
 }
@@ -29,12 +28,11 @@ __device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
   ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
 }
-template <typename x_t>
-struct L2NormFunctor {
-  __device__ __forceinline__ void operator()(
-      int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
-      float *output, float *output_per_tensor, bool per_tensor,
-      int max_chunks_per_tensor) {
+template <typename x_t> struct L2NormFunctor {
+  __device__ __forceinline__ void
+  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
+             float *output, float *output_per_tensor, bool per_tensor,
+             int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;
@@ -50,8 +48,8 @@ struct L2NormFunctor {
     __shared__ float s_vals[512];
-    float vals[ILP];  // = {0}; // this probably works too but I want to be
-                      // sure...
+    float
+        vals[ILP]; // = {0}; // this probably works too but I want to be sure...
     x_t r_x[ILP];
     for (int i = 0; i < ILP; i++) {
       vals[i] = 0.f;
@@ -86,14 +84,15 @@
     }
     float val = 0.f;
-    for (int i = 0; i < ILP; i++) val += vals[i];
+    for (int i = 0; i < ILP; i++)
+      val += vals[i];
     float final = reduce_block_into_lanes(s_vals, val);
     if (threadIdx.x == 0) {
       if (!isfinite(final))
         *noop_gmem =
             1; // Blindly fire off a write. These will race but that's ok.
       output[blockIdx.x] += final;
       if (per_tensor)
         output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
@@ -105,12 +104,11 @@
 // Probably better to template, but since we are not likely to support other
 // norm
-template <typename x_t>
-struct MaxNormFunctor {
-  __device__ __forceinline__ void operator()(
-      int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
-      float *output, float *output_per_tensor, bool per_tensor,
-      int max_chunks_per_tensor) {
+template <typename x_t> struct MaxNormFunctor {
+  __device__ __forceinline__ void
+  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<1> &tl,
+             float *output, float *output_per_tensor, bool per_tensor,
+             int max_chunks_per_tensor) {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
    //   return;
@@ -126,8 +124,8 @@ struct MaxNormFunctor {
     __shared__ float s_vals[512];
-    float vals[ILP];  // = {0}; // this probably works too but I want to be
-                      // sure...
+    float
+        vals[ILP]; // = {0}; // this probably works too but I want to be sure...
     x_t r_x[ILP];
     for (int i = 0; i < ILP; i++) {
      vals[i] = 0.f;
@@ -162,14 +160,15 @@
     }
     float val = 0.f;
-    for (int i = 0; i < ILP; i++) val = fmaxf(fabsf(val), fabsf(vals[i]));
+    for (int i = 0; i < ILP; i++)
+      val = fmaxf(fabsf(val), fabsf(vals[i]));
     float final = reduce_block_into_lanes_max_op(s_vals, val);
     if (threadIdx.x == 0) {
       if (!isfinite(final))
         *noop_gmem =
             1; // Blindly fire off a write. These will race but that's ok.
       output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
       if (per_tensor)
         output_per_tensor[(tl.start_tensor_this_launch + tensor_loc) *
@@ -186,11 +185,13 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
   if (blockIdx.x == 0) {
     float val = 0;
-    if (threadIdx.x < 320) val = output[threadIdx.x];
+    if (threadIdx.x < 320)
+      val = output[threadIdx.x];
     float final = reduce_block_into_lanes(vals, val);
-    if (threadIdx.x == 0) *ret = sqrt(final);
+    if (threadIdx.x == 0)
+      *ret = sqrt(final);
   }
   if (per_tensor) {
@@ -203,7 +204,8 @@ __global__ void cleanup(float *output, float *output_per_tensor, float *ret,
     float final = reduce_block_into_lanes(vals, val);
-    if (threadIdx.x == 0) ret_per_tensor[blockIdx.x] = sqrt(final);
+    if (threadIdx.x == 0)
+      ret_per_tensor[blockIdx.x] = sqrt(final);
   }
 }
@@ -215,14 +217,17 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
   if (blockIdx.x == 0) {
     float val = 0;
-    if (threadIdx.x < 320) val = output[threadIdx.x];
+    if (threadIdx.x < 320)
+      val = output[threadIdx.x];
     if (norm_type == 0) {
       float final = reduce_block_into_lanes_max_op(vals, val);
-      if (threadIdx.x == 0) *ret = alpha * (*ret) + beta * final;
+      if (threadIdx.x == 0)
+        *ret = alpha * (*ret) + beta * final;
     } else {
       float final = reduce_block_into_lanes(vals, val);
-      if (threadIdx.x == 0) *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
+      if (threadIdx.x == 0)
+        *ret = sqrt(alpha * (*ret) * (*ret) + beta * final);
     }
   }
@@ -255,10 +260,10 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret,
   }
 }
-std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
-    int chunk_size, at::Tensor noop_flag,
-    std::vector<std::vector<at::Tensor>> tensor_lists,
-    at::optional<bool> per_tensor_python) {
+std::tuple<at::Tensor, at::Tensor>
+multi_tensor_l2norm_cuda(int chunk_size, at::Tensor noop_flag,
+                         std::vector<std::vector<at::Tensor>> tensor_lists,
+                         at::optional<bool> per_tensor_python) {
   bool per_tensor =
       per_tensor_python.has_value() ? per_tensor_python.value() : false;
......
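L2NormFunctor and MaxNormFunctor both follow the shape visible in these hunks: each thread accumulates ILP partial values, folds them into a single scalar, and only then enters the block-wide reduction; the racy `*noop_gmem = 1` write is deliberate, since any thread that sees a non-finite value may set the flag. A sketch of the per-thread fold step:

#define ILP 4

__device__ float fold_l2_partials(const float (&vals)[ILP]) {
  float val = 0.f;
  for (int i = 0; i < ILP; i++)
    val += vals[i]; // the max-norm variant uses fmaxf(fabsf(val), fabsf(vals[i]))
  return val; // fed to reduce_block_into_lanes for the block-wide result
}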