Commit 0c7d8e3f (unverified), authored by Masaki Kozuki, committed by GitHub
Browse files

remove THC headers/functions (#1192)

Changes include:
- Removal of THC headers
- Replacement of TH macros (e.g. THCublasCheck -> TORCH_CUDABLAS_CHECK)
- Fixes for some typos in comments
parent 60821f53
...@@ -163,7 +163,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \ ...@@ -163,7 +163,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
s.data_ptr(), s.data_ptr(),
p_dropout); p_dropout);
// number of times random will be generated per thread, to offset philox counter in thc random // number of times random will be generated per thread, to offset philox counter in the random
// state // state
int64_t counter_offset = elts_per_thread; int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs; at::PhiloxCudaState rng_engine_inputs;
...@@ -319,7 +319,7 @@ std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num ...@@ -319,7 +319,7 @@ std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num
s.data_ptr(), s.data_ptr(),
p_dropout); p_dropout);
// number of times random will be generated per thread, to offset philox counter in thc random // number of times random will be generated per thread, to offset philox counter in the random
// state // state
int64_t counter_offset = elts_per_thread; int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs; at::PhiloxCudaState rng_engine_inputs;
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h> #include <c10/cuda/CUDACachingAllocator.h>
#include "THC/THC.h"
#include "batch_norm.h" #include "batch_norm.h"
#include <cuda.h> #include <cuda.h>
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h> #include <c10/cuda/CUDACachingAllocator.h>
#include "THC/THC.h"
#include "batch_norm_add_relu.h" #include "batch_norm_add_relu.h"
#include <cuda.h> #include <cuda.h>
......
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include "THC/THC.h"
#include <cuda.h> #include <cuda.h>
#include "compat.h" #include "compat.h"
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "softmax.h" #include "softmax.h"
#include "dropout.h" #include "dropout.h"
......
...@@ -9,8 +9,6 @@ ...@@ -9,8 +9,6 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <curand_kernel.h> #include <curand_kernel.h>
#include <THC/THCGeneral.h>
const int UNROLL = 4; const int UNROLL = 4;
template < template <
...@@ -207,7 +205,7 @@ void apex_fused_dropout_cuda(scalar_t const *inputs, ...@@ -207,7 +205,7 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size; unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
//number of times random will be generated per thread, to offset philox counter in thc random state //number of times random will be generated per thread, to offset philox counter in the random state
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL; int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs; std::pair<uint64_t, uint64_t> rng_engine_inputs;
{ {
...@@ -245,7 +243,7 @@ void apex_dropout_add_cuda(scalar_t const *inputs, ...@@ -245,7 +243,7 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size; unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
//number of times random will be generated per thread, to offset philox counter in thc random state //number of times random will be generated per thread, to offset philox counter in the random state
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL; int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs; std::pair<uint64_t, uint64_t> rng_engine_inputs;
{ {
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -86,9 +86,9 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -86,9 +86,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd // Input Linear Q Fwd
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_q_dim, output_lin_q_dim,
...@@ -109,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -109,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear KV Fwd // Input Linear KV Fwd
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_kv_dim, output_lin_kv_dim,
...@@ -211,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -211,7 +211,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -232,7 +232,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -232,7 +232,7 @@ std::vector<torch::Tensor> fwd_cuda(
//CUBLAS_GEMM_ALGO1_TENSOR_OP)); //CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_lin_q_results, input_lin_q_results,
...@@ -312,10 +312,10 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -312,10 +312,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -336,7 +336,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -336,7 +336,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -456,7 +456,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -456,7 +456,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches); attn_batches);
// Input Linear Q Dgrad // Input Linear Q Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -478,7 +478,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -478,7 +478,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Q Wgrad // Input Linear Q Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -499,7 +499,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -499,7 +499,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear KV Dgrad // Input Linear KV Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -521,7 +521,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -521,7 +521,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear KV Wgrad // Input Linear KV Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -540,7 +540,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -540,7 +540,7 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_q_grads, input_q_grads,
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -95,7 +95,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -95,7 +95,7 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
HostApplyLayerNorm<at::Half,float>( HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()), static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
...@@ -109,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -109,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr())); static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Q Fwd // Input Linear Q Fwd
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_q_dim, output_lin_q_dim,
...@@ -131,7 +131,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -131,7 +131,7 @@ std::vector<torch::Tensor> fwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear KV Fwd // Input Linear KV Fwd
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_kv_dim, output_lin_kv_dim,
...@@ -234,7 +234,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -234,7 +234,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -272,7 +272,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -272,7 +272,7 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens_q); total_tokens_q);
} }
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
lyr_nrm_results, lyr_nrm_results,
...@@ -367,7 +367,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -367,7 +367,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward // Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
...@@ -378,7 +378,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -378,7 +378,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob))); (1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -399,7 +399,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -399,7 +399,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -519,7 +519,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -519,7 +519,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches); attn_batches);
// Input Linear Q Dgrad // Input Linear Q Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -542,7 +542,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -542,7 +542,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Q Wgrad // Input Linear Q Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -563,7 +563,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -563,7 +563,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear KV Dgrad // Input Linear KV Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -585,7 +585,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -585,7 +585,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear KV Wgrad // Input Linear KV Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -622,7 +622,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -622,7 +622,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr()) static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
); );
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_q_grads, input_q_grads,
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "softmax.h" #include "softmax.h"
#include "dropout.h" #include "dropout.h"
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -82,10 +82,10 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -82,10 +82,10 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
input_lin_results.copy_(input_biases); input_lin_results.copy_(input_biases);
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -173,7 +173,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -173,7 +173,7 @@ std::vector<torch::Tensor> fwd_cuda(
outputs.copy_(output_biases); outputs.copy_(output_biases);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -194,7 +194,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -194,7 +194,7 @@ std::vector<torch::Tensor> fwd_cuda(
//CUBLAS_GEMM_ALGO1_TENSOR_OP)); //CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_lin_results, input_lin_results,
...@@ -264,10 +264,10 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -264,10 +264,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -287,7 +287,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -287,7 +287,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -403,7 +403,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -403,7 +403,7 @@ std::vector<torch::Tensor> bwd_cuda(
batch_stride, batch_stride,
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -426,7 +426,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -426,7 +426,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -447,7 +447,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -447,7 +447,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -81,10 +81,10 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -81,10 +81,10 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
input_lin_results.copy_(input_biases); input_lin_results.copy_(input_biases);
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -185,7 +185,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -185,7 +185,7 @@ std::vector<torch::Tensor> fwd_cuda(
outputs.copy_(output_biases); outputs.copy_(output_biases);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -206,7 +206,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -206,7 +206,7 @@ std::vector<torch::Tensor> fwd_cuda(
//CUBLAS_GEMM_ALGO1_TENSOR_OP)); //CUBLAS_GEMM_ALGO1_TENSOR_OP));
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_lin_results, input_lin_results,
...@@ -275,10 +275,10 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -275,10 +275,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -298,7 +298,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -298,7 +298,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -411,7 +411,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -411,7 +411,7 @@ std::vector<torch::Tensor> bwd_cuda(
batch_stride, batch_stride,
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -434,7 +434,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -434,7 +434,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -455,7 +455,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -455,7 +455,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -78,9 +78,9 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -78,9 +78,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -182,7 +182,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -182,7 +182,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -202,7 +202,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -202,7 +202,7 @@ std::vector<torch::Tensor> fwd_cuda(
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_lin_results, input_lin_results,
...@@ -271,10 +271,10 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -271,10 +271,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -295,7 +295,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -295,7 +295,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -415,7 +415,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -415,7 +415,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -436,7 +436,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -436,7 +436,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -455,7 +455,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -455,7 +455,7 @@ std::vector<torch::Tensor> bwd_cuda(
embed_dim, embed_dim,
CUDA_R_32F, CUDA_R_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -88,7 +88,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -88,7 +88,7 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
HostApplyLayerNorm<at::Half,float>( HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()), static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
...@@ -102,7 +102,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -102,7 +102,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr())); static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Fwd // Input Linear Fwd
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -206,7 +206,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -206,7 +206,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -243,7 +243,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -243,7 +243,7 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens); total_tokens);
} }
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
lyr_nrm_results, lyr_nrm_results,
...@@ -327,7 +327,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -327,7 +327,7 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward // Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
...@@ -338,7 +338,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -338,7 +338,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob))); (1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -359,7 +359,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -359,7 +359,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -479,7 +479,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -479,7 +479,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -502,7 +502,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -502,7 +502,7 @@ std::vector<torch::Tensor> bwd_cuda(
CUBLAS_GEMM_DEFAULT_TENSOR_OP)); CUBLAS_GEMM_DEFAULT_TENSOR_OP));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(cublasGemmEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmEx(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -540,7 +540,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -540,7 +540,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr()) static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
); );
THCublasCheck(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH)); TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
......
#include <vector> #include <vector>
#include <iostream> #include <iostream>
//#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include <cuda_profiler_api.h> #include <cuda_profiler_api.h>
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include "cutlass/cutlass.h" #include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h" #include "cutlass/gemm/gemm.h"
...@@ -23,7 +23,7 @@ cublasOperation_t convertTransToCublasOperation(char trans) { ...@@ -23,7 +23,7 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
else if (trans == 'n') return CUBLAS_OP_N; else if (trans == 'n') return CUBLAS_OP_N;
else if (trans == 'c') return CUBLAS_OP_C; else if (trans == 'c') return CUBLAS_OP_C;
else { else {
THError("trans must be one of: t, n, c"); AT_ERROR("trans must be one of: t, n, c");
return CUBLAS_OP_T; return CUBLAS_OP_T;
} }
} }
...@@ -40,7 +40,7 @@ void CublasStridedBatchedGemm(THCState *state, char transa, char transb, long m, ...@@ -40,7 +40,7 @@ void CublasStridedBatchedGemm(THCState *state, char transa, char transb, long m,
float fAlpha = alpha; float fAlpha = alpha;
float fBeta = beta; float fBeta = beta;
//THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH)); //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
THCublasCheck(cublasGemmStridedBatchedEx(handle, TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedEx(handle,
opa, opb, (int)m, (int)n, (int)k, opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA, (void*)&fAlpha, a, CUDA_R_16F, (int)lda, strideA,
b, CUDA_R_16F, (int)ldb, strideB, b, CUDA_R_16F, (int)ldb, strideB,
...@@ -316,7 +316,7 @@ void HgemmStridedBatched(THCState *state, char transa, char transb, long m, long ...@@ -316,7 +316,7 @@ void HgemmStridedBatched(THCState *state, char transa, char transb, long m, long
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
{ {
THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
"with the bound [val] <= %d", INT_MAX); "with the bound [val] <= %d", INT_MAX);
} }
......
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <stdio.h> #include <stdio.h>
#include <cmath> #include <cmath>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include "ATen/TensorUtils.h" #include "ATen/TensorUtils.h"
// #include "ATen/Type.h" // #include "ATen/Type.h"
#include "ATen/AccumulateType.h" #include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include "multi_tensor_apply.cuh" #include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include <THC/THCGeneral.h>
// Another possibility: // Another possibility:
// #include <torch/all.h> // #include <torch/all.h>
......
#include <torch/extension.h>
#include <cuda.h> #include <cuda.h>
#include <curand_kernel.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <c10/macros/Macros.h>
#include <THC/THC.h> #include <torch/extension.h>
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/CUDAGeneratorImpl.h> #include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh> #include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h> #include <c10/macros/Macros.h>
#include "philox.h" #include "philox.h"
// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width. // Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.
......
#include <torch/extension.h> #include <vector>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <vector>
#include <torch/extension.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <THC/THC.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
template<typename scalar_t> template<typename scalar_t>
......
...@@ -76,9 +76,6 @@ ...@@ -76,9 +76,6 @@
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/NumericLimits.cuh> #include <ATen/cuda/NumericLimits.cuh>
#include <THC/THC.h>
#include <THC/THCGeneral.h>
#include "type_shim.h" #include "type_shim.h"
#include "compat.h" #include "compat.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment