OpenDAS / apex

Commit c7b34549, authored Apr 09, 2020 by Thor Johnsen
Add no-flattening e5m2-allgather option
Parent: cd206434

Showing 4 changed files with 157 additions and 10 deletions (+157 / -10)
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp         +3    -1
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu   +113  -2
apex/contrib/optimizers/distributed_fused_adam.py        +15   -7
csrc/type_shim.h                                          +26   -0
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp
@@ -10,6 +10,7 @@ void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::v
 void fused_adam_undo_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
 void unpack_e5m2_cuda(at::Tensor& p_in, at::Tensor& p_out);
+void unpack_e5m2_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
@@ -64,7 +65,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
         m.def("adam", &adam, "Adam optimized CUDA implementation.");
         m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
-        m.def("unpack_e5m2", &unpack_e5m2, "Unpack byte tensor containing e5m2 floats.");
         m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
         m.def("adam_undo_mt", &fused_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
+        m.def("unpack_e5m2", &unpack_e5m2, "Unpack byte tensor containing e5m2 floats.");
+        m.def("unpack_e5m2_mt", &unpack_e5m2_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
 }
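The new unpack_e5m2_mt binding is driven from Python through apex's multi_tensor_applier, which supplies the chunk size and forwards the overflow flag and tensor lists; the distributed_fused_adam.py change further down does exactly that. Below is a minimal sketch, assuming the fused_adam_cuda extension from apex.contrib is built and a CUDA device is available; the tensor names and sizes are illustrative only.

import importlib
import torch
from apex.multi_tensor_apply import multi_tensor_applier

fused_adam_cuda = importlib.import_module("fused_adam_cuda")

overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")                # noop/overflow flag
packed = [torch.randint(0, 256, (1024,), dtype=torch.uint8, device="cuda")]  # e5m2 bytes in
unpacked = [torch.empty(1024, dtype=torch.float16, device="cuda")]           # fp16 out

# tensor_lists has depth 2, [inputs, outputs]; the kernel sets overflow_buf to 1
# if any unpacked value is non-finite.
multi_tensor_applier(fused_adam_cuda.unpack_e5m2_mt, overflow_buf, [packed, unpacked])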
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu
@@ -126,6 +126,36 @@ __global__ void strided_check_finite_cuda_kernel(
         }
     }
 }
+template <>
+__global__ void strided_check_finite_cuda_kernel(
+        volatile int* noop_gmem,
+        uint8_t* __restrict__ p_copy,
+        const size_t tsize,
+        int stride,
+        int clear_overflow_first)
+{
+    //Assuming 2D grids and 2D blocks
+    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
+    const int threadsPerBlock = blockDim.x * blockDim.y;
+    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
+    const int i = (blockId * threadsPerBlock + threadIdInBlock) * stride;
+    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock * stride;
+
+    if (clear_overflow_first) {
+        if (i == 0) {
+            *noop_gmem = 0;
+        }
+        __syncthreads();
+    }
+
+    for (int j = i; j < tsize; j += totThreads) {
+        at::Half pi;
+        convert(p_copy[j], pi);
+        if (!isfinite(pi)) {
+            *noop_gmem = 1;
+        }
+    }
+}
 template <typename T, typename GRAD_T, typename REDU_T>
 __global__ void adam_cuda_kernel(
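The new uint8_t specialization can check the packed allgather buffer directly because of how e5m2 relates to fp16: both formats use a sign bit and a 5-bit exponent, so an e5m2 byte is simply the high byte of an IEEE fp16 value with the low 8 mantissa bits dropped. The convert() overload used above is not part of this diff, so treating it as that high-byte reinterpretation is an assumption; the host-side numpy sketch below illustrates the assumed convention.

import numpy as np

def unpack_e5m2(packed: np.ndarray) -> np.ndarray:
    # packed: uint8 array of e5m2 bit patterns (1 sign, 5 exponent, 2 mantissa bits).
    # Shift each byte into the upper half of a 16-bit word and reinterpret as fp16;
    # the low 8 mantissa bits come back as zero.
    assert packed.dtype == np.uint8
    return (packed.astype(np.uint16) << 8).view(np.float16)

def pack_e5m2(x: np.ndarray) -> np.ndarray:
    # Truncating pack (no rounding): keep only the high byte of each fp16 value.
    assert x.dtype == np.float16
    return (x.view(np.uint16) >> 8).astype(np.uint8)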
@@ -337,6 +367,65 @@ __global__ void adam_undo_cuda_kernel(
     }
 }
+
+template <int DEPTH, typename FROM_T, typename TO_T>
+struct UnpackE5M2Functor
+{
+    __device__ __forceinline__ void operator()(
+            int chunk_size,
+            volatile int* noop_gmem,
+            TensorListMetadata<DEPTH>& tl)
+    {
+        if (*noop_gmem != 0) return;
+
+        int tensor_loc = tl.block_to_tensor[blockIdx.x];
+        int chunk_idx = tl.block_to_chunk[blockIdx.x];
+        int n = tl.sizes[tensor_loc];
+
+        FROM_T* p_in = (FROM_T*) tl.addresses[0][tensor_loc];
+        p_in += chunk_idx * chunk_size;
+        TO_T* p_out = (TO_T*) tl.addresses[1][tensor_loc];
+        p_out += chunk_idx * chunk_size;
+
+        n -= chunk_idx * chunk_size;
+        int dim = chunk_size < n ? chunk_size : n;
+
+        FROM_T pi[ILP];
+        TO_T po[ILP];
+
+        bool overflow = false;
+        for (int j_start = 0; j_start < dim; j_start += blockDim.x * ILP) {
+#pragma unroll
+            for (int ii = 0; ii < ILP; ii++) {
+                pi[ii] = FROM_T(0);
+                int j = j_start + threadIdx.x + ii * blockDim.x;
+                if (j < dim) {
+                    pi[ii] = p_in[j];
+                }
+            }
+
+#pragma unroll
+            for (int ii = 0; ii < ILP; ii++) {
+                convert(pi[ii], po[ii]);
+                if (!isfinite(po[ii])) {
+                    overflow = true;
+                }
+            }
+
+#pragma unroll
+            for (int ii = 0; ii < ILP; ii++) {
+                int j = j_start + threadIdx.x + ii * blockDim.x;
+                if (j < dim) {
+                    p_out[j] = po[ii];
+                }
+            }
+        }
+        if (overflow) {
+            *noop_gmem = 1;
+        }
+    }
+};
+
 template <int DEPTH, typename T, typename GRAD_T>
 struct AdamFunctor
 {
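Stripped of the chunking and ILP machinery, UnpackE5M2Functor does the following per (input, output) pair: skip the work if the overflow flag is already set, unpack each byte, write the result, and raise the flag if any unpacked value is non-finite. A rough host-side equivalent, under the same packing assumption as the sketch above:

import numpy as np

def unpack_e5m2_pairs(noop_flag, tensor_pairs):
    # noop_flag: one-element int array standing in for *noop_gmem.
    # tensor_pairs: list of (uint8 input, float16/float32 output) array pairs.
    if noop_flag[0] != 0:
        return  # an earlier kernel already reported overflow
    for p_in, p_out in tensor_pairs:
        vals = (p_in.astype(np.uint16) << 8).view(np.float16)
        if not np.isfinite(vals).all():
            noop_flag[0] = 1  # mirrors *noop_gmem = 1
        p_out[...] = vals.astype(p_out.dtype)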
@@ -533,7 +622,7 @@ void fused_strided_check_finite(
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     using namespace at; // prevents "toString is undefined" errors
-    DISPATCH_FLOAT_AND_HALF(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
+    DISPATCH_FLOAT_HALF_AND_BYTE(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
         strided_check_finite_cuda_kernel<scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
             noop.DATA_PTR<int>(),
             p_copy.DATA_PTR<scalar_t_0>(),
@@ -669,6 +758,28 @@ void unpack_e5m2_cuda(
     THCudaCheck(cudaGetLastError());
 }
+
+void unpack_e5m2_cuda_mt(
+    int chunk_size,
+    at::Tensor noop_flag,
+    std::vector<std::vector<at::Tensor>> tensor_lists) // p_in, p_out
+{
+    //Constants
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    size_t tl_sz = tensor_lists.size();
+    AT_ASSERTM(tl_sz == 2, "expected tensor lists of size 2");
+
+    DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 0, "unpack_e5m2_cuda_mt_kernel",
+        multi_tensor_apply<2>(
+            BLOCK_SIZE,
+            chunk_size,
+            noop_flag,
+            tensor_lists,
+            UnpackE5M2Functor<2, uint8_t, scalar_t_0>());
+        );
+
+    THCudaCheck(cudaGetLastError());
+}
+
 void fused_adam_undo_cuda(
     at::Tensor & p,
     at::Tensor & m,
apex/contrib/optimizers/distributed_fused_adam.py
 import math
 import torch
 import importlib
+from apex.multi_tensor_apply import multi_tensor_applier

 class DistributedFusedAdam(torch.optim.Optimizer):
@@ -559,17 +560,15 @@ class DistributedFusedAdam(torch.optim.Optimizer):
             # Check for overflow
             # Store state for loss scaler calculation
-            if self._e5m2_allgather:
-                new_params = torch.empty_like(self._flat_grads)
-                fused_adam_cuda.unpack_e5m2(self._new_params, new_params)
-            else:
-                new_params = self._new_params
-            self.strided_check_finite(new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
+            self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
             if self.peek_overflow:
                 print("Reverting step")
                 self.revert_step()
             else:
                 # Copy self._new_params to model params
+                if self._e5m2_allgather:
+                    p_in = []
+                    p_out = []
                 with torch.no_grad():
                     param_i = 0
                     for group in self.param_groups:
@@ -582,8 +581,17 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                             state['step'] += 1
                             nels = p.numel()
                             offset = self._grads_info[param_i]['param_offset']
-                            p.set_(new_params[offset:offset+nels].view_as(p))
+                            if self._e5m2_allgather:
+                                p_in.append(self._new_params[offset:offset+nels].view_as(p))
+                                p_out.append(p)
+                            else:
+                                p.set_(self._new_params[offset:offset+nels].view_as(p))
                             param_i += 1
+                if self._e5m2_allgather:
+                    multi_tensor_applier(fused_adam_cuda.unpack_e5m2_mt, self._overflow_buf, [p_in, p_out]);
                 self._new_params = None
                 torch.cuda.current_stream().wait_stream(self._blk_st[0])
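The e5m2 path trades precision for allgather bandwidth: only the top byte of each fp16 parameter value travels over the wire, and with this commit those bytes are unpacked straight into the model parameters instead of going through a full-size flat buffer first. A quick round trip under the packing assumption sketched earlier shows the cost, roughly two mantissa bits of precision:

import numpy as np

x = np.array([1.0, 1.25, 3.1416, -0.1], dtype=np.float16)
packed = (x.view(np.uint16) >> 8).astype(np.uint8)            # truncating e5m2 pack
restored = (packed.astype(np.uint16) << 8).view(np.float16)   # unpack
print(restored)  # -> [1.0, 1.25, 3.0, -0.09375]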
csrc/type_shim.h
@@ -34,6 +34,32 @@
 }

+#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...) \
+    switch(TYPE) \
+    { \
+        case at::ScalarType::Float: \
+        { \
+            using scalar_t_##LEVEL = float; \
+            __VA_ARGS__; \
+            break; \
+        } \
+        case at::ScalarType::Half: \
+        { \
+            using scalar_t_##LEVEL = at::Half; \
+            __VA_ARGS__; \
+            break; \
+        } \
+        case at::ScalarType::Byte: \
+        { \
+            using scalar_t_##LEVEL = uint8_t; \
+            __VA_ARGS__; \
+            break; \
+        } \
+        default: \
+            AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
+    }
+
 #define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
 switch(TYPE) \
 { \