Commit ebb4e88a authored by hubertlu-tw

Enable the --focal_loss and --index_mul_2d extensions on ROCm

parent 40e15362
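Every hunk in the .cu file below makes the same mechanical change: the braced initializer list that was passed as the block-dimension argument of the `<<< >>>` launch configuration is bound to a named `dim3` variable first. The commit message does not spell out the motivation, but a plausible one is that CUDA-to-HIP translation rewrites triple-chevron launches into the `hipLaunchKernelGGL` macro, and the commas inside a braced list such as `{16, 16, 1}` are then misread as macro-argument separators. A minimal sketch of the before/after pattern, using a hypothetical `scale_kernel` in place of the index_mul_2d kernels in this diff:

```cpp
#include <cuda_runtime.h>

// Hypothetical kernel, standing in for index_mul_2d_float and friends.
__global__ void scale_kernel(float *out, const float *in, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = 2.0f * in[i];
}

void launch_scale(float *out, const float *in, int n, cudaStream_t stream) {
    const int BLOCK_THREADS = 256;
    const int BLOCK_NUMS = (n + BLOCK_THREADS - 1) / BLOCK_THREADS;

    // Fragile form (what the commit removes): the commas inside the braced
    // list can break the launch once hipify rewrites it as a macro call.
    // scale_kernel<<<BLOCK_NUMS, {BLOCK_THREADS, 1, 1}, 0, stream>>>(out, in, n);

    // Portable form (what the commit adopts): name the dim3 before launching.
    dim3 threads(BLOCK_THREADS, 1, 1);
    scale_kernel<<<BLOCK_NUMS, threads, 0, stream>>>(out, in, n);
}
```

The named-`dim3` form is equally valid under plain CUDA, so the change costs nothing on NVIDIA hardware.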
@@ -311,16 +311,18 @@ void index_mul_2d_float_foward_cuda(at::Tensor &out,
         const int BLOCK_THREADS_DIMX = 16;
         const int BLOCK_THREADS_DIMY = 16;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
             out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
             idx1.data_ptr<int64_t>(), size);
     } else {
         const int BLOCK_THREADS_DIMX = 32;
         const int BLOCK_THREADS_DIMY = 8;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_float<<<BLOCK_NUMS, threads, 0, stream>>>(
             out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
             idx1.data_ptr<int64_t>(), size, fea_dim);
     }
@@ -346,8 +348,9 @@ void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1,
         const int BLOCK_THREADS_DIMX = 16;
         const int BLOCK_THREADS_DIMY = 16;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_grad_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_grad_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
             grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
             in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
@@ -356,8 +359,9 @@ void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1,
         const int BLOCK_THREADS_DIMX = 32;
         const int BLOCK_THREADS_DIMY = 8;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_grad_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_grad_float<<<BLOCK_NUMS, threads, 0, stream>>>(
             grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
             in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
     }
@@ -384,8 +388,9 @@ void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out,
         const int BLOCK_THREADS_DIMX = 16;
         const int BLOCK_THREADS_DIMY = 16;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_grad_grad_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_grad_grad_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
             grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
             grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
             in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
@@ -393,8 +398,9 @@ void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out,
         const int BLOCK_THREADS_DIMX = 32;
         const int BLOCK_THREADS_DIMY = 8;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_grad_grad_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_grad_grad_float<<<BLOCK_NUMS, threads, 0, stream>>>(
             grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
             grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
             in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
@@ -418,8 +424,9 @@ void index_mul_2d_half_foward_cuda(at::Tensor &out,
         const int BLOCK_THREADS_DIMX = 32;
         const int BLOCK_THREADS_DIMY = 8;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_half<<<BLOCK_NUMS, threads, 0, stream>>>(
             out.data_ptr<at::Half>(), in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(),
             idx1.data_ptr<int64_t>(), size, fea_dim);
@@ -443,8 +450,9 @@ void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1,
         const int BLOCK_THREADS_DIMX = 32;
         const int BLOCK_THREADS_DIMY = 8;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_grad_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_grad_half<<<BLOCK_NUMS, threads, 0, stream>>>(
             grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(), grad_out.data_ptr<at::Half>(),
             in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);
     }
@@ -469,8 +477,9 @@ void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out,
         const int BLOCK_THREADS_DIMX = 32;
         const int BLOCK_THREADS_DIMY = 8;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_grad_grad_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_grad_grad_half<<<BLOCK_NUMS, threads, 0, stream>>>(
             grad_grad_out.data_ptr<at::Half>(), grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(),
             grad_out.data_ptr<at::Half>(), grad_grad_in1.data_ptr<at::Half>(), grad_grad_in2.data_ptr<at::Half>(),
             in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);
...
@@ -2,7 +2,7 @@ import unittest
 import sys
-test_dirs = ["groupbn", "layer_norm", "multihead_attn", "."] # "." for test_label_smoothing.py
+test_dirs = ["groupbn", "layer_norm", "multihead_attn", "focal_loss", "index_mul_2d", "."] # "." for test_label_smoothing.py
 ROCM_BLACKLIST = [
     "layer_norm"
 ]
...
@@ -307,7 +307,25 @@ if "--xentropy" in sys.argv or "--cuda_ext" in sys.argv:
                       extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                           'nvcc':['-O3'] + version_dependent_macros}))

-if "--index_mul_2d" in sys.argv:
+if "--focal_loss" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--focal_loss" in sys.argv:
+        sys.argv.remove("--focal_loss")
+    ext_modules.append(
+        CUDAExtension(
+            name='focal_loss_cuda',
+            sources=[
+                'apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp',
+                'apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu',
+            ],
+            include_dirs=[os.path.join(this_dir, 'csrc')],
+            extra_compile_args={
+                'cxx': ['-O3'] + version_dependent_macros,
+                'nvcc':(['-O3', '--use_fast_math', '--ftz=false'] if not IS_ROCM_PYTORCH else ['-O3']) + version_dependent_macros,
+            },
+        )
+    )
+
+if "--index_mul_2d" in sys.argv or "--cuda_ext" in sys.argv:
     if "--index_mul_2d" in sys.argv:
         sys.argv.remove("--index_mul_2d")
     ext_modules.append(
...
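With this setup.py change, focal_loss_cuda joins index_mul_2d_cuda as a contrib extension, and both now also build under the umbrella --cuda_ext flag rather than only via their individual flags. On ROCm, the nvcc-only options --use_fast_math and --ftz=false are dropped, since hipcc does not accept them. Assuming apex's usual build convention, an invocation like `pip install -v --no-cache-dir --global-option="--focal_loss" --global-option="--index_mul_2d" ./` (or simply passing --cuda_ext) should now compile both extensions on either platform.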