OpenDAS / apex

Commit e57f5d0e authored Apr 09, 2019 by Michael Carilli

Simple cut of the kernel in place

parent 03100f46

Showing 3 changed files with 60 additions and 1 deletion (+60 -1):

  csrc/amp_C_frontend.cpp   +7  -0
  csrc/type_shim.h          +51 -0
  setup.py                  +2  -1
csrc/amp_C_frontend.cpp

@@ -14,9 +14,16 @@ void multi_tensor_axpby_cuda(
   float b,
   int arg_to_check);
 
+at::Tensor multi_tensor_l2norm_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
         "Fused overflow check + scale for a list of contiguous tensors");
   m.def("multi_tensor_axpby", &multi_tensor_axpby_cuda,
         "out = a*x + b*y for a list of contiguous tensors");
+  m.def("multi_tensor_l2norm", &multi_tensor_l2norm_cuda,
+        "Computes L2 norm for a list of contiguous tensors");
 }
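For orientation, a hedged caller sketch (not part of this commit) of the newly declared multi_tensor_l2norm_cuda entry point, using only the signature shown above. The one-element int noop_flag, the 2048*32 chunk size, and the tensor shapes are illustrative assumptions modeled on how the existing multi-tensor bindings are typically fed.

// Hypothetical caller sketch -- not part of this commit, only an illustration
// of the signature declared above. The chunk size and the int32 noop_flag are
// assumptions borrowed from how apex's other multi-tensor ops are commonly
// driven; they are not dictated by this diff.
#include <torch/extension.h>
#include <vector>

// Declaration as added by this commit in csrc/amp_C_frontend.cpp.
at::Tensor multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists);

at::Tensor l2norm_of(const std::vector<at::Tensor>& params)
{
  auto cuda_float = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  // One-element flag the multi-tensor kernels can set to signal a problem
  // (e.g. an overflow detected during the pass).
  at::Tensor noop_flag = at::zeros({1}, cuda_float.dtype(at::kInt));
  // The interface takes a list of tensor lists; l2norm needs only one list.
  std::vector<std::vector<at::Tensor>> tensor_lists = {params};
  return multi_tensor_l2norm_cuda(2048 * 32, noop_flag, tensor_lists);
}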
csrc/type_shim.h

@@ -31,3 +31,54 @@ struct TypeShim
     default: \
       AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
   }
+
+template<typename T, typename ReduceOp>
+__device__ __forceinline__ T reduce_block_into_lanes
+  (T *x,
+   T val,
+   int lanes,
+   bool share_result)
+// lanes is intended to be <= 32.
+{
+  int tid = threadIdx.x + threadIdx.y*blockDim.x;
+  int blockSize = blockDim.x*blockDim.y;
+  // blockSize is intended to be a multiple of 32.
+
+  if(blockSize >= 64)
+  {
+    x[tid] = val;
+    __syncthreads();
+  }
+
+  #pragma unroll
+  for(int i = (blockSize >> 1); i >= 64; i >>= 1)
+  {
+    if(tid < i)
+      x[tid] = x[tid] + x[tid+i];
+    __syncthreads();
+  }
+
+  T final;
+
+  if(tid < 32)
+  {
+    if(blockSize >= 64)
+      final = x[tid] + x[tid+32];
+    else
+      final = val;
+    // __SYNCWARP();
+
+    #pragma unroll
+    for(int i = 16; i >= lanes; i >>= 1)
+      final = final + __shfl_down_sync(0xffffffff, final, i);
+  }
+
+  if(share_result)
+  {
+    if(tid < lanes)
+      x[tid] = final;
+    // EpilogueOp
+    // Make sure the smem result is visible to all warps.
+    __syncthreads();
+  }
+
+  return final;
+}
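To make the intended call pattern concrete, here is a small, hypothetical CUDA kernel (not part of this commit) that uses reduce_block_into_lanes as declared above to produce one partial sum per block. The kernel name, the fixed 512-thread block size, and the explicit <float, float> template arguments (ReduceOp is declared but unused in this cut, so it cannot be deduced from the call) are all illustrative assumptions.

// Hypothetical usage sketch, assuming type_shim.h is on the include path.
// Each block sums one float per thread into block_sums[blockIdx.x].
#include "type_shim.h"

__global__ void block_sum_kernel(const float* in, float* block_sums, int n)
{
  // Scratch space for the shared-memory phase; must hold one value per thread.
  __shared__ float scratch[512];

  int tid = threadIdx.x + threadIdx.y*blockDim.x;
  int gid = blockIdx.x*(blockDim.x*blockDim.y) + tid;
  float val = (gid < n) ? in[gid] : 0.f;

  // lanes=1 collapses the whole block into lane 0; share_result=false because
  // only thread 0 consumes the result. ReduceOp is unused in this cut, so any
  // type (float here) satisfies the explicit template argument.
  float total = reduce_block_into_lanes<float, float>(scratch, val, 1, false);

  if(tid == 0)
    block_sums[blockIdx.x] = total;
}

// Launch example: 512 threads per block, a power of two and a multiple of 32,
// as the helper expects.
// block_sum_kernel<<<num_blocks, 512>>>(d_in, d_block_sums, n);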
setup.py

@@ -71,7 +71,8 @@ if "--cuda_ext" in sys.argv:
         CUDAExtension(name='amp_C',
                       sources=['csrc/amp_C_frontend.cpp',
                                'csrc/multi_tensor_scale_kernel.cu',
-                               'csrc/multi_tensor_axpby_kernel.cu'],
+                               'csrc/multi_tensor_axpby_kernel.cu',
+                               'csrc/multi_tensor_l2norm_kernel.cu'],
                       extra_compile_args={'cxx': ['-O3'],
                                           'nvcc':['-lineinfo',
                                                   '-O3',
...