Commit 2155dabf authored by Masaki Kozuki's avatar Masaki Kozuki Committed by hubertlu-tw
Browse files

remove THC headers/functions (#1192)

Changes include
- THC headers removal
- TH macros replacement
- fix some typo in comment
 Conflicts:
	apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu
	apex/contrib/csrc/multihead_attn/encdec_multihead_attn_cuda.cu
	apex/contrib/csrc/multihead_attn/encdec_multihead_attn_norm_add_cuda.cu
	apex/contrib/csrc/multihead_attn/masked_softmax_dropout_cuda.cu
	apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_additive_mask_cuda.cu
	apex/contrib/csrc/multihead_attn/self_multihead_attn_bias_cuda.cu
	apex/contrib/csrc/multihead_attn/self_multihead_attn_cuda.cu
	apex/contrib/csrc/multihead_attn/self_multihead_attn_norm_add_cuda.cu
	apex/contrib/csrc/multihead_attn/strided_batched_gemm.h
parent 79a2d204
...@@ -163,7 +163,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \ ...@@ -163,7 +163,7 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, total := \
s.data_ptr(), s.data_ptr(),
p_dropout); p_dropout);
// number of times random will be generated per thread, to offset philox counter in thc random // number of times random will be generated per thread, to offset philox counter in the random
// state // state
int64_t counter_offset = elts_per_thread; int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs; at::PhiloxCudaState rng_engine_inputs;
...@@ -319,7 +319,7 @@ std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num ...@@ -319,7 +319,7 @@ std::vector<at::Tensor> mha_fwd_nl(const at::Tensor &qkv, // total x num
s.data_ptr(), s.data_ptr(),
p_dropout); p_dropout);
// number of times random will be generated per thread, to offset philox counter in thc random // number of times random will be generated per thread, to offset philox counter in the random
// state // state
int64_t counter_offset = elts_per_thread; int64_t counter_offset = elts_per_thread;
at::PhiloxCudaState rng_engine_inputs; at::PhiloxCudaState rng_engine_inputs;
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h> #include <c10/cuda/CUDACachingAllocator.h>
#include "THC/THC.h"
#include "batch_norm.h" #include "batch_norm.h"
#include <cuda.h> #include <cuda.h>
......
...@@ -2,8 +2,6 @@ ...@@ -2,8 +2,6 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h> #include <c10/cuda/CUDACachingAllocator.h>
#include "THC/THC.h"
#include "batch_norm_add_relu.h" #include "batch_norm_add_relu.h"
#include <cuda.h> #include <cuda.h>
......
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include "THC/THC.h"
#include <cuda.h> #include <cuda.h>
#include "compat.h" #include "compat.h"
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
<<<<<<< HEAD
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include "THC/THC.h" #include "THC/THC.h"
=======
#include <cuda_profiler_api.h>
#include <ATen/ATen.h>
>>>>>>> 0c7d8e3 (remove THC headers/functions (#1192))
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "softmax.h" #include "softmax.h"
#include "dropout.h" #include "dropout.h"
......
...@@ -9,8 +9,6 @@ ...@@ -9,8 +9,6 @@
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <curand_kernel.h> #include <curand_kernel.h>
#include <THC/THCGeneral.h>
const int UNROLL = 4; const int UNROLL = 4;
template < template <
...@@ -207,7 +205,7 @@ void apex_fused_dropout_cuda(scalar_t const *inputs, ...@@ -207,7 +205,7 @@ void apex_fused_dropout_cuda(scalar_t const *inputs,
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size; unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
//number of times random will be generated per thread, to offset philox counter in thc random state //number of times random will be generated per thread, to offset philox counter in the random state
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL; int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs; std::pair<uint64_t, uint64_t> rng_engine_inputs;
{ {
...@@ -245,7 +243,7 @@ void apex_dropout_add_cuda(scalar_t const *inputs, ...@@ -245,7 +243,7 @@ void apex_dropout_add_cuda(scalar_t const *inputs,
unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size; unsigned int blocks_per_sm = at::cuda::getCurrentDeviceProperties()->maxThreadsPerMultiProcessor/block_size;
grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x); grid.x = std::min((unsigned int)at::cuda::getCurrentDeviceProperties()->multiProcessorCount * blocks_per_sm, grid.x);
//number of times random will be generated per thread, to offset philox counter in thc random state //number of times random will be generated per thread, to offset philox counter in the random state
int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL; int64_t counter_offset = ((totalElements - 1)/(block_size*grid.x*UNROLL)+1)*UNROLL;
std::pair<uint64_t, uint64_t> rng_engine_inputs; std::pair<uint64_t, uint64_t> rng_engine_inputs;
{ {
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h>
#include "THC/THC.h" #include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -89,9 +86,9 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -89,9 +86,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Q Fwd // Input Linear Q Fwd
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_q_dim, output_lin_q_dim,
...@@ -117,7 +114,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -117,7 +114,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags)); flags));
// Input Linear KV Fwd // Input Linear KV Fwd
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_kv_dim, output_lin_kv_dim,
...@@ -230,7 +227,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -230,7 +227,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -254,6 +251,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -254,6 +251,7 @@ std::vector<torch::Tensor> fwd_cuda(
algo, algo,
solution_index, solution_index,
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_lin_q_results, input_lin_q_results,
...@@ -333,8 +331,10 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -333,8 +331,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -360,7 +360,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -360,7 +360,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -497,7 +497,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -497,7 +497,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches); attn_batches);
// Input Linear Q Dgrad // Input Linear Q Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -523,7 +523,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -523,7 +523,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Input Linear Q Wgrad // Input Linear Q Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -549,7 +549,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -549,7 +549,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Input Linear KV Dgrad // Input Linear KV Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -575,7 +575,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -575,7 +575,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Input Linear KV Wgrad // Input Linear KV Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -599,7 +599,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -599,7 +599,7 @@ std::vector<torch::Tensor> bwd_cuda(
algo, algo,
solution_index, solution_index,
flags)); flags));
// TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_q_grads, input_q_grads,
input_kv_grads, input_kv_grads,
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include "THC/THC.h" //#include <cuda_profiler_api.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -99,6 +95,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -99,6 +95,7 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Layer Norm // Layer Norm
HostApplyLayerNorm<at::Half,float>( HostApplyLayerNorm<at::Half,float>(
static_cast<at::Half*>(lyr_nrm_results.data_ptr()), static_cast<at::Half*>(lyr_nrm_results.data_ptr()),
...@@ -112,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -112,7 +109,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr())); static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Q Fwd // Input Linear Q Fwd
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_q_dim, output_lin_q_dim,
...@@ -139,7 +136,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -139,7 +136,7 @@ std::vector<torch::Tensor> fwd_cuda(
flags)); flags));
// Input Linear KV Fwd // Input Linear KV Fwd
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_kv_dim, output_lin_kv_dim,
...@@ -252,7 +249,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -252,7 +249,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -294,6 +291,8 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -294,6 +291,8 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens_q); total_tokens_q);
} }
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
lyr_nrm_results, lyr_nrm_results,
lyr_nrm_mean, lyr_nrm_mean,
...@@ -387,6 +386,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -387,6 +386,8 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward // Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()), static_cast<at::Half const*>(output_grads.data_ptr()),
...@@ -396,7 +397,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -396,7 +397,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob))); (1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -422,7 +423,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -422,7 +423,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -559,7 +560,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -559,7 +560,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches); attn_batches);
// Input Linear Q Dgrad // Input Linear Q Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -586,7 +587,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -586,7 +587,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Input Linear Q Wgrad // Input Linear Q Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -612,7 +613,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -612,7 +613,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Input Linear KV Dgrad // Input Linear KV Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -638,7 +639,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -638,7 +639,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Input Linear KV Wgrad // Input Linear KV Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -680,6 +681,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -680,6 +681,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr()) static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
); );
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_q_grads, input_q_grads,
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include "THC/THC.h" //#include <cuda_profiler_api.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "softmax.h" #include "softmax.h"
#include "dropout.h" #include "dropout.h"
......
#include <vector> #include <vector>
#include <math.h> #include <math.h>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
...@@ -85,9 +82,10 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -85,9 +82,10 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
input_lin_results.copy_(input_biases); input_lin_results.copy_(input_biases);
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -187,7 +185,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -187,7 +185,7 @@ std::vector<torch::Tensor> fwd_cuda(
outputs.copy_(output_biases); outputs.copy_(output_biases);
// Output Linear // Output Linear
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -211,6 +209,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -211,6 +209,7 @@ std::vector<torch::Tensor> fwd_cuda(
algo, algo,
solution_index, solution_index,
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_lin_results, input_lin_results,
...@@ -280,8 +279,10 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -280,8 +279,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -307,7 +308,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -307,7 +308,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -441,7 +442,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -441,7 +442,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -467,7 +468,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -467,7 +468,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -493,6 +494,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -493,6 +494,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -84,9 +81,10 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -84,9 +81,10 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
input_lin_results.copy_(input_biases); input_lin_results.copy_(input_biases);
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -199,7 +197,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -199,7 +197,7 @@ std::vector<torch::Tensor> fwd_cuda(
outputs.copy_(output_biases); outputs.copy_(output_biases);
// Output Linear // Output Linear
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -223,6 +221,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -223,6 +221,7 @@ std::vector<torch::Tensor> fwd_cuda(
algo, algo,
solution_index, solution_index,
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_lin_results, input_lin_results,
...@@ -291,8 +290,10 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -291,8 +290,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -318,7 +319,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -318,7 +319,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -448,7 +449,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -448,7 +449,7 @@ std::vector<torch::Tensor> bwd_cuda(
batch_stride, batch_stride,
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -474,7 +475,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -474,7 +475,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -500,6 +501,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -500,6 +501,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false); auto input_bias_grads = input_lin_output_grads.view({-1, output_lin_dim}).sum(0, false);
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <cuda_profiler_api.h> //#include <cuda_profiler_api.h>
#include "THC/THC.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -81,8 +78,9 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -81,8 +78,9 @@ std::vector<torch::Tensor> fwd_cuda(
char a_layout_n{'n'}; char a_layout_n{'n'};
char b_layout_n{'n'}; char b_layout_n{'n'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Input Linear Fwd // Input Linear Fwd
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -195,7 +193,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -195,7 +193,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -219,6 +217,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -219,6 +217,7 @@ std::vector<torch::Tensor> fwd_cuda(
algo, algo,
solution_index, solution_index,
flags)); flags));
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_lin_results, input_lin_results,
...@@ -287,8 +286,10 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -287,8 +286,10 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -314,7 +315,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -314,7 +315,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -451,7 +452,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -451,7 +452,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -477,7 +478,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -477,7 +478,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -501,6 +502,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -501,6 +502,7 @@ std::vector<torch::Tensor> bwd_cuda(
algo, algo,
solution_index, solution_index,
flags)); flags));
TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
......
#include <vector> #include <vector>
#include <math.h>
#include <iostream> #include <iostream>
//below lines enable hip float to half conversion which are disabled by default in hip_fp16.h
#undef __HIP_NO_HALF_OPERATORS__
#undef __HIP_NO_HALF_CONVERSIONS__
//#endif
#include <ATen/ATen.h>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#include "THC/THC.h" //#include <cuda_profiler_api.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h> #include <torch/extension.h>
#include <math.h>
#include "strided_batched_gemm.h" #include "strided_batched_gemm.h"
#include "softmax.h" #include "softmax.h"
...@@ -106,7 +102,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -106,7 +102,7 @@ std::vector<torch::Tensor> fwd_cuda(
static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr())); static_cast<const at::Half*>(lyr_nrm_beta_weights.data_ptr()));
// Input Linear Fwd // Input Linear Fwd
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
output_lin_dim, output_lin_dim,
...@@ -221,7 +217,7 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -221,7 +217,7 @@ std::vector<torch::Tensor> fwd_cuda(
attn_batches); attn_batches);
// Output Linear // Output Linear
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_T, CUBLAS_OP_T,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -264,6 +260,8 @@ std::vector<torch::Tensor> fwd_cuda( ...@@ -264,6 +260,8 @@ std::vector<torch::Tensor> fwd_cuda(
total_tokens); total_tokens);
} }
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
lyr_nrm_results, lyr_nrm_results,
lyr_nrm_mean, lyr_nrm_mean,
...@@ -346,6 +344,8 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -346,6 +344,8 @@ std::vector<torch::Tensor> bwd_cuda(
char b_layout_n{'n'}; char b_layout_n{'n'};
char b_layout_t{'t'}; char b_layout_t{'t'};
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
// Dropout Add Backward // Dropout Add Backward
apex_masked_scale_cuda<at::Half,float,uint32_t>( apex_masked_scale_cuda<at::Half,float,uint32_t>(
static_cast<at::Half const*>(output_grads.data_ptr()), static_cast<at::Half const*>(output_grads.data_ptr()),
...@@ -355,7 +355,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -355,7 +355,7 @@ std::vector<torch::Tensor> bwd_cuda(
(1.0 / (1.0 - dropout_prob))); (1.0 / (1.0 - dropout_prob)));
// Output Linear Dgrad // Output Linear Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -381,7 +381,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -381,7 +381,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Output Linear Wgrad // Output Linear Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -518,7 +518,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -518,7 +518,7 @@ std::vector<torch::Tensor> bwd_cuda(
attn_batches); attn_batches);
// Input Linear Dgrad // Input Linear Dgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_N, CUBLAS_OP_N,
embed_dim, embed_dim,
...@@ -545,7 +545,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -545,7 +545,7 @@ std::vector<torch::Tensor> bwd_cuda(
flags)); flags));
// Input Linear Wgrad // Input Linear Wgrad
THCublasCheck(rocblas_gemm_ex(handle, TORCH_CUDABLAS_CHECK(rocblas_gemm_ex(handle,
CUBLAS_OP_N, CUBLAS_OP_N,
CUBLAS_OP_T, CUBLAS_OP_T,
embed_dim, embed_dim,
...@@ -588,6 +588,7 @@ std::vector<torch::Tensor> bwd_cuda( ...@@ -588,6 +588,7 @@ std::vector<torch::Tensor> bwd_cuda(
static_cast<half*>(lyr_nrm_beta_grads.data_ptr()) static_cast<half*>(lyr_nrm_beta_grads.data_ptr())
); );
//TORCH_CUDABLAS_CHECK(cublasSetMathMode(handle, CUBLAS_DEFAULT_MATH));
return { return {
input_grads, input_grads,
......
...@@ -5,9 +5,10 @@ ...@@ -5,9 +5,10 @@
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
//#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include "THC/THC.h"
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// symbol to be automatically resolved by PyTorch libs // symbol to be automatically resolved by PyTorch libs
extern THCState *state; extern THCState *state;
...@@ -28,7 +29,7 @@ cublasOperation_t convertTransToCublasOperation(char trans) { ...@@ -28,7 +29,7 @@ cublasOperation_t convertTransToCublasOperation(char trans) {
else if (trans == 'n') return CUBLAS_OP_N; else if (trans == 'n') return CUBLAS_OP_N;
else if (trans == 'c') return CUBLAS_OP_C; else if (trans == 'c') return CUBLAS_OP_C;
else { else {
THError("trans must be one of: t, n, c"); AT_ERROR("trans must be one of: t, n, c");
return CUBLAS_OP_T; return CUBLAS_OP_T;
} }
} }
...@@ -44,7 +45,8 @@ void RocblasStridedBatchedGemm(THCState *state, char transa, char transb, long m ...@@ -44,7 +45,8 @@ void RocblasStridedBatchedGemm(THCState *state, char transa, char transb, long m
cublasSetStream(handle, stream); cublasSetStream(handle, stream);
float fAlpha = alpha; float fAlpha = alpha;
float fBeta = beta; float fBeta = beta;
THCublasCheck(rocblas_gemm_strided_batched_ex(handle, //THCublasCheck(cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH));
TORCH_CUDABLAS_CHECK(rocblas_gemm_strided_batched_ex(handle,
opa, opb, (int)m, (int)n, (int)k, opa, opb, (int)m, (int)n, (int)k,
(void*)&fAlpha, a, a_type, (int)lda, strideA, (void*)&fAlpha, a, a_type, (int)lda, strideA,
b, b_type, (int)ldb, strideB, b, b_type, (int)ldb, strideB,
...@@ -112,7 +114,7 @@ void HgemmStridedBatched(THCState *state, char transa, char transb, long m, long ...@@ -112,7 +114,7 @@ void HgemmStridedBatched(THCState *state, char transa, char transb, long m, long
if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) ) if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
{ {
THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount" AT_ERROR("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
"with the bound [val] <= %d", INT_MAX); "with the bound [val] <= %d", INT_MAX);
} }
......
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <stdio.h> #include <stdio.h>
#include <cmath> #include <cmath>
#include "ATen/ATen.h"
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/detail/IndexUtils.cuh"
#include "ATen/TensorUtils.h" #include "ATen/TensorUtils.h"
// #include "ATen/Type.h" // #include "ATen/Type.h"
#include "ATen/AccumulateType.h" #include "ATen/AccumulateType.h"
#include <THC/THCGeneral.h>
#include "multi_tensor_apply.cuh" #include "multi_tensor_apply.cuh"
#define BLOCK_SIZE 512 #define BLOCK_SIZE 512
......
...@@ -2,7 +2,6 @@ ...@@ -2,7 +2,6 @@
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h> #include <ATen/cuda/Exceptions.h>
#include <THC/THCGeneral.h>
// Another possibility: // Another possibility:
// #include <torch/all.h> // #include <torch/all.h>
......
#include <torch/extension.h>
#include <cuda.h> #include <cuda.h>
#include <curand_kernel.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <c10/macros/Macros.h>
#include <THC/THC.h> #include <torch/extension.h>
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/CUDAGeneratorImpl.h> #include <ATen/CUDAGeneratorImpl.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGraphsUtils.cuh> #include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <curand_kernel.h> #include <c10/macros/Macros.h>
#include "philox.h" #include "philox.h"
// Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width. // Warp reduce kernels to reduce N groups of data into N numbers, where N = warpSize / width.
......
#include <torch/extension.h> #include <vector>
#include <cuda.h> #include <cuda.h>
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <vector>
#include <torch/extension.h>
#include <ATen/ATen.h> #include <ATen/ATen.h>
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <THC/THC.h>
#include <ATen/cuda/CUDAContext.h> #include <ATen/cuda/CUDAContext.h>
template<typename scalar_t> template<typename scalar_t>
......
...@@ -76,9 +76,6 @@ ...@@ -76,9 +76,6 @@
#include <ATen/AccumulateType.h> #include <ATen/AccumulateType.h>
#include <ATen/cuda/NumericLimits.cuh> #include <ATen/cuda/NumericLimits.cuh>
#include <THC/THC.h>
#include <THC/THCGeneral.h>
#include "type_shim.h" #include "type_shim.h"
#include "compat.h" #include "compat.h"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment