OpenDAS / apex · Commit ebb4e88a

Enable --focal_loss and --index_mul_2d_cuda extensions on ROCm

Authored Aug 23, 2022 by hubertlu-tw
Parent 40e15362
Showing 3 changed files with 41 additions and 14 deletions.
apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu   +21 -12
apex/contrib/test/run_rocm_extensions.py                      +1  -1
setup.py                                                     +19  -1
apex/contrib/csrc/index_mul_2d/index_mul_2d_cuda_kernel.cu

@@ -311,16 +311,18 @@ void index_mul_2d_float_foward_cuda(at::Tensor &out,
         const int BLOCK_THREADS_DIMX = 16;
         const int BLOCK_THREADS_DIMY = 16;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
             out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
             idx1.data_ptr<int64_t>(), size);
     } else {
         const int BLOCK_THREADS_DIMX = 32;
         const int BLOCK_THREADS_DIMY = 8;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_float<<<BLOCK_NUMS, threads, 0, stream>>>(
             out.data_ptr<float>(), in1.data_ptr<float>(), in2.data_ptr<float>(),
             idx1.data_ptr<int64_t>(), size, fea_dim);
     }

@@ -346,8 +348,9 @@ void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1,
         const int BLOCK_THREADS_DIMX = 16;
         const int BLOCK_THREADS_DIMY = 16;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_grad_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_grad_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
             grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
             in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);

@@ -356,8 +359,9 @@ void index_mul_2d_float_backward_cuda(at::Tensor &grad_in1,
         const int BLOCK_THREADS_DIMX = 32;
         const int BLOCK_THREADS_DIMY = 8;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_grad_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_grad_float<<<BLOCK_NUMS, threads, 0, stream>>>(
             grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(), grad_out.data_ptr<float>(),
             in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);
     }

@@ -384,8 +388,9 @@ void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out,
         const int BLOCK_THREADS_DIMX = 16;
         const int BLOCK_THREADS_DIMY = 16;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_grad_grad_float_dim64<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_grad_grad_float_dim64<<<BLOCK_NUMS, threads, 0, stream>>>(
             grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
             grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
             in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size);

@@ -393,8 +398,9 @@ void index_mul_2d_float_backward_backward_cuda(at::Tensor &grad_grad_out,
         const int BLOCK_THREADS_DIMX = 32;
         const int BLOCK_THREADS_DIMY = 8;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_grad_grad_float<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_grad_grad_float<<<BLOCK_NUMS, threads, 0, stream>>>(
             grad_grad_out.data_ptr<float>(), grad_in1.data_ptr<float>(), grad_in2.data_ptr<float>(),
             grad_out.data_ptr<float>(), grad_grad_in1.data_ptr<float>(), grad_grad_in2.data_ptr<float>(),
             in1.data_ptr<float>(), in2.data_ptr<float>(), idx1.data_ptr<int64_t>(), size, fea_dim);

@@ -418,8 +424,9 @@ void index_mul_2d_half_foward_cuda(at::Tensor &out,
         const int BLOCK_THREADS_DIMX = 32;
         const int BLOCK_THREADS_DIMY = 8;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_half<<<BLOCK_NUMS, threads, 0, stream>>>(
             out.data_ptr<at::Half>(), in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(),
             idx1.data_ptr<int64_t>(), size, fea_dim);

@@ -443,8 +450,9 @@ void index_mul_2d_half_backward_cuda(at::Tensor &grad_in1,
         const int BLOCK_THREADS_DIMX = 32;
         const int BLOCK_THREADS_DIMY = 8;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_grad_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_grad_half<<<BLOCK_NUMS, threads, 0, stream>>>(
             grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(), grad_out.data_ptr<at::Half>(),
             in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);
     }

@@ -469,8 +477,9 @@ void index_mul_2d_half_backward_backward_cuda(at::Tensor &grad_grad_out,
         const int BLOCK_THREADS_DIMX = 32;
         const int BLOCK_THREADS_DIMY = 8;
         const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
-        index_mul_2d_grad_grad_half<<<BLOCK_NUMS, {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1}, 0, stream>>>(
+        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
+        index_mul_2d_grad_grad_half<<<BLOCK_NUMS, threads, 0, stream>>>(
             grad_grad_out.data_ptr<at::Half>(), grad_in1.data_ptr<at::Half>(), grad_in2.data_ptr<at::Half>(),
             grad_out.data_ptr<at::Half>(), grad_grad_in1.data_ptr<at::Half>(), grad_grad_in2.data_ptr<at::Half>(),
             in1.data_ptr<at::Half>(), in2.data_ptr<at::Half>(), idx1.data_ptr<int64_t>(), size, fea_dim);
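All eight hunks above make the same mechanical change: the block shape that was written inline as the braced list {BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1} is hoisted into a named dim3 before the launch. The likely motivation is that ROCm rewrites triple-chevron launches into the variadic hipLaunchKernelGGL macro, where the bare commas inside a braced initializer list get split across separate macro arguments; a named dim3 sidesteps that and is equally valid under nvcc. Below is a minimal, self-contained sketch of the pattern; scale2d and launch_scale2d are hypothetical stand-ins, not apex kernels.

    // sketch.cu: the hipify-safe launch-configuration pattern used in this commit.
    // scale2d / launch_scale2d are illustrative examples, not apex code.
    #include <cuda_runtime.h>
    #include <cstdint>

    __global__ void scale2d(float *data, float factor, int64_t size, int fea_dim) {
        // Each block handles blockDim.y rows; threadIdx.x strides across features.
        int64_t row = static_cast<int64_t>(blockIdx.x) * blockDim.y + threadIdx.y;
        if (row >= size) return;
        for (int col = threadIdx.x; col < fea_dim; col += blockDim.x) {
            data[row * fea_dim + col] *= factor;
        }
    }

    void launch_scale2d(float *data, float factor, int64_t size, int fea_dim,
                        cudaStream_t stream) {
        const int BLOCK_THREADS_DIMX = 32;
        const int BLOCK_THREADS_DIMY = 8;
        // Ceiling division: enough blocks so BLOCK_NUMS * DIMY covers all rows.
        const int BLOCK_NUMS = (size + BLOCK_THREADS_DIMY - 1) / BLOCK_THREADS_DIMY;
        // Named dim3 instead of an inline {x, y, 1}: survives hipify's macro rewrite.
        dim3 threads(BLOCK_THREADS_DIMX, BLOCK_THREADS_DIMY, 1);
        scale2d<<<BLOCK_NUMS, threads, 0, stream>>>(data, factor, size, fea_dim);
    }

After hipification this launch is expressed roughly as hipLaunchKernelGGL(scale2d, dim3(BLOCK_NUMS), threads, 0, stream, data, factor, size, fea_dim); with an inline braced list in the block-dimension slot, the preprocessor would have treated each comma-separated element as its own macro argument.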
apex/contrib/test/run_rocm_extensions.py

@@ -2,7 +2,7 @@ import unittest
 import sys
 
-test_dirs = ["groupbn", "layer_norm", "multihead_attn", "."]  # "." for test_label_smoothing.py
+test_dirs = ["groupbn", "layer_norm", "multihead_attn", "focal_loss", "index_mul_2d", "."]  # "." for test_label_smoothing.py
 
 ROCM_BLACKLIST = ["layer_norm"]
setup.py

@@ -307,7 +307,25 @@ if "--xentropy" in sys.argv or "--cuda_ext" in sys.argv:
                                       extra_compile_args={'cxx': ['-O3'] + version_dependent_macros,
                                                           'nvcc':['-O3'] + version_dependent_macros}))
 
-if "--index_mul_2d" in sys.argv:
+if "--focal_loss" in sys.argv or "--cuda_ext" in sys.argv:
+    if "--focal_loss" in sys.argv:
+        sys.argv.remove("--focal_loss")
+    ext_modules.append(
+        CUDAExtension(
+            name='focal_loss_cuda',
+            sources=[
+                'apex/contrib/csrc/focal_loss/focal_loss_cuda.cpp',
+                'apex/contrib/csrc/focal_loss/focal_loss_cuda_kernel.cu',
+            ],
+            include_dirs=[os.path.join(this_dir, 'csrc')],
+            extra_compile_args={
+                'cxx': ['-O3'] + version_dependent_macros,
+                'nvcc':(['-O3', '--use_fast_math', '--ftz=false'] if not IS_ROCM_PYTORCH else ['-O3']) + version_dependent_macros,
+            },
+        )
+    )
+
+if "--index_mul_2d" in sys.argv or "--cuda_ext" in sys.argv:
     if "--index_mul_2d" in sys.argv:
         sys.argv.remove("--index_mul_2d")
     ext_modules.append(
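The setup.py hunk mirrors the existing index_mul_2d block for focal_loss and gates the nvcc-only flags on IS_ROCM_PYTORCH, since hipcc does not accept '--use_fast_math' or '--ftz=false'. On CUDA the pairing is deliberate: --use_fast_math turns flush-to-zero on, and the trailing --ftz=false switches it back off so subnormal intermediates are preserved. The probe below illustrates the behavior the flag controls; it is a hypothetical sketch assuming a CUDA toolchain, not part of the commit.

    // ftz_probe.cu (hypothetical): prints a subnormal product. Compiled with
    // --use_fast_math alone the value is flushed to 0; adding --ftz=false
    // restores the subnormal result.
    #include <cfloat>
    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void subnormal_probe(float *out) {
        float tiny = FLT_MIN;   // smallest normal float, ~1.18e-38
        out[0] = tiny * 0.5f;   // subnormal result: ~5.9e-39, or 0 under FTZ
    }

    int main() {
        float h = 0.0f, *d = nullptr;
        cudaMalloc(&d, sizeof(float));
        subnormal_probe<<<1, 1>>>(d);
        cudaMemcpy(&h, d, sizeof(float), cudaMemcpyDeviceToHost);
        printf("FLT_MIN * 0.5f = %g\n", h);   // 0 when flush-to-zero is active
        cudaFree(d);
        return 0;
    }

Since both option blocks consume their flag from sys.argv, the extensions can still be requested individually (--focal_loss, --index_mul_2d) or together via --cuda_ext.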