Commit ebb4e88a authored by hubertlu-tw

Enable --focal_loss and --index_mul_2d_cuda extensions on ROCm

parent 40e15362
@@ -311,16 +311,18 @@ void index_mul_2d_float_foward_cuda(at::Tensor &out,
 const int BLOCK_THREADS_DIMX = 16;
 const int BLOCK_THREADS_DIMY = 16;
 const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
+dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
-index_mul_2d_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+index_mul_2d_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
 out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
 idx1.data_ptr<int64_t>(), size);
 } else {
 const int BLOCK_THREADS_DIMX = 32;
 const int BLOCK_THREADS_DIMY = 8;
 const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
+dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
-index_mul_2d_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+index_mul_2d_float<<<BLOCK_NUMS, threads, 0, stream>>>(
 out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
 idx1.data_ptr<int64_t>(), size, fea_dim);
 }
@@ -346,8 +348,9 @@ void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1,
 const int BLOCK_THREADS_DIMX = 16;
 const int BLOCK_THREADS_DIMY = 16;
 const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
+dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
-index_mul_2d_grad_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+index_mul_2d_grad_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
 grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
 in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
@@ -356,8 +359,9 @@ void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1,
 const int BLOCK_THREADS_DIMX = 32;
 const int BLOCK_THREADS_DIMY = 8;
 const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
+dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
-index_mul_2d_grad_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+index_mul_2d_grad_float<<<BLOCK_NUMS, threads, 0, stream>>>(
 grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
 in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
 }
@@ -384,8 +388,9 @@ void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out,
 const int BLOCK_THREADS_DIMX = 16;
 const int BLOCK_THREADS_DIMY = 16;
 const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
+dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
-index_mul_2d_grad_grad_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+index_mul_2d_grad_grad_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
 grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
 grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
 in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);
@@ -393,8 +398,9 @@ void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out,
 const int BLOCK_THREADS_DIMX = 32;
 const int BLOCK_THREADS_DIMY = 8;
 const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
+dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
-index_mul_2d_grad_grad_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+index_mul_2d_grad_grad_float<<<BLOCK_NUMS, threads, 0, stream>>>(
 grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
 grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
 in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
@@ -418,8 +424,9 @@ void index_mul_2d_half_foward_cuda(at::Tensor &out,
 const int BLOCK_THREADS_DIMX = 32;
 const int BLOCK_THREADS_DIMY = 8;
 const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
+dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
-index_mul_2d_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+index_mul_2d_half<<<BLOCK_NUMS, threads, 0, stream>>>(
 out.data_ptr<at::Half>(), in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(),
 idx1.data_ptr<int64_t>(), size, fea_dim);
@@ -443,8 +450,9 @@ void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1,
 const int BLOCK_THREADS_DIMX = 32;
 const int BLOCK_THREADS_DIMY = 8;
 const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
+dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
-index_mul_2d_grad_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+index_mul_2d_grad_half<<<BLOCK_NUMS, threads, 0, stream>>>(
 grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(), grad_out.data_ptr<at::Half>(),
 in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);
 }
@@ -469,8 +477,9 @@ void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out,
 const int BLOCK_THREADS_DIMX = 32;
 const int BLOCK_THREADS_DIMY = 8;
 const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
+dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
-index_mul_2d_grad_grad_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+index_mul_2d_grad_grad_half<<<BLOCK_NUMS, threads, 0, stream>>>(
 grad_grad_out.data_ptr<at::Half>(), grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(),
 grad_out.data_ptr<at::Half>(), grad_grad_in1.data_ptr<at::Half>(), grad_grad_in2.data_ptr<at::Half>(),
 in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);
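
Note on the kernel-launch change above: every launch that previously passed a braced initializer list (`{BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}`) directly inside the `<<<...>>>` configuration now builds a named `dim3 threads(...)` first and passes the variable. The commit message does not spell out the reason, but the usual motivation is HIP portability: when such a launch is rewritten to the function-like `hipLaunchKernelGGL(kernel, grid, block, shmem, stream, ...)` macro, the commas inside `{x, y, 1}` are treated as macro-argument separators and the build breaks, while a named `dim3` is a single argument and works under both nvcc and hipcc. A minimal, self-contained sketch of the resulting launch pattern (the `scale_rows` kernel below is a hypothetical stand-in, not apex code):

```cuda
// Minimal sketch (assumed example, not apex code): a throwaway kernel launched
// with the same 2D block shape the index_mul_2d kernels use, showing the
// named-dim3 launch configuration that builds cleanly for both CUDA and HIP.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>

__global__ void scale_rows(float *out, const float *in, int64_t size, int64_t fea_dim) {
    // blockDim.y rows per block; blockDim.x threads stride across the feature dim.
    const int64_t row = static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
    if (row >= size) return;
    for (int64_t col = threadIdx.x; col < fea_dim; col += blockDim.x) {
        out[row * fea_dim + col] = 2.0f * in[row * fea_dim + col];
    }
}

int main() {
    const int64_t size = 1024, fea_dim = 64;
    float *in = nullptr, *out = nullptr;
    cudaMallocManaged(&in, size * fea_dim * sizeof(float));
    cudaMallocManaged(&out, size * fea_dim * sizeof(float));
    for (int64_t i = 0; i < size * fea_dim; ++i) in[i] = 1.0f;

    const int BLOCK_THREADS_DIMX = 32;
    const int BLOCK_THREADS_DIMY = 8;
    const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;

    // Build the block shape as a named dim3 and pass the variable, rather than
    // writing {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1} inside <<<...>>>.
    dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
    scale_rows<<<BLOCK_NUMS, threads, 0, 0>>>(out, in, size, fea_dim);
    cudaDeviceSynchronize();

    printf("out[0] = %f (expect 2.0)\n", out[0]);
    cudaFree(in);
    cudaFree(out);
    return 0;
}
```

Under nvcc both launch forms compile to the same thing, so the change should be behavior-neutral on CUDA builds.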
@@ -2,7 +2,7 @@ import unittest
 import sys
-test_dirs = ["groupbn", "layer_norm", "multihead_attn", "."] # "." for test_label_smoothing.py
+test_dirs = ["groupbn", "layer_norm", "multihead_attn", "focal_loss", "index_mul_2d", "."] # "." for test_label_smoothing.py
 ROCM_BLACKLIST = [
 "layer_norm"
 ]
@@ -307,7 +307,25 @@ if "--xentropy" in sys.argv or "--cuda_ext" in sys.argv:
 extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
 'nvcc':['-O3'] + version_dependent_macros}))
-if "--index_mul_2d" in sys.argv:
+if "--focal_loss" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--focal_loss" in sys.argv:
+        sys.argv.remove("--focal_loss")
+    ext_modules.append(
+        CUDAExtension(
+            name='focal_loss_cuda',
+            sources=[
+                'apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp',
+                'apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu',
+            ],
+            include_dirs=[os.path.join(this_dir, 'csrc')],
+            extra_compile_args={
+                'cxx': ['-O3'] + version_dependent_macros,
+                'nvcc':(['-O3', '--use_fast_math', '--ftz=false'] if not IS_ROCM_PYTORCH else ['-O3']) + version_dependent_macros,
+            },
+        )
+    )
+if "--index_mul_2d" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--index_mul_2d" in sys.argv:
         sys.argv.remove("--index_mul_2d")
     ext_modules.append(
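
A reading of the setup.py hunk above (not stated in the commit message): both extension blocks are now also entered when the generic `--cuda_ext` flag is passed, and on ROCm the nvcc-specific options `--use_fast_math` and `--ftz=false` are dropped in favor of plain `-O3`, since hipcc does not take nvcc's spelling of those flags; `IS_ROCM_PYTORCH` is presumably the flag setup.py already derives from the installed PyTorch build.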