OpenDAS / apex · Commits

Commit cd206434
Authored Apr 09, 2020 by Thor Johnsen
Parent: aa90d31f

Add e5m2 allgather option
Showing 3 changed files with 273 additions and 78 deletions:

  apex/contrib/csrc/optimizers/fused_adam_cuda.cpp         +10   -0
  apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu   +245  -73
  apex/contrib/optimizers/distributed_fused_adam.py        +18   -5
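Here, "e5m2" is a 1-5-2 float: 1 sign bit, 5 exponent bits, 2 mantissa bits, which is exactly the high byte of an IEEE fp16 value. Packing updated parameters this way lets the allgather move one byte per element instead of two, and the bytes are unpacked back to fp16 afterwards. A minimal pure-PyTorch sketch of that byte layout, illustrative only and not taken from the commit (little-endian host assumed):

import torch

# fp16 layout is [sign | 5 exponent bits | 10 mantissa bits]; e5m2 keeps only the
# high byte, i.e. [sign | 5 exponent bits | 2 mantissa bits]. The values below are
# exactly representable in e5m2, so the round trip is lossless.
x = torch.tensor([1.5, -0.375, 1024.0], dtype=torch.float16)
high_bytes = x.view(torch.uint8).reshape(-1, 2)[:, 1].clone()   # one byte per value

restored = torch.zeros(high_bytes.numel(), 2, dtype=torch.uint8)
restored[:, 1] = high_bytes                                     # low byte stays zero
print(restored.view(torch.float16).reshape(-1))                 # [1.5, -0.375, 1024.] as float16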
apex/contrib/csrc/optimizers/fused_adam_cuda.cpp  (view file @ cd206434)

@@ -9,6 +9,7 @@ void fused_adam_undo_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Te
 void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
 void fused_adam_undo_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void unpack_e5m2_cuda(at::Tensor & p_in, at::Tensor & p_out);

 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
@@ -50,11 +51,20 @@ void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, f
         fused_adam_undo_cuda(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
 }

+void unpack_e5m2(at::Tensor & p_in, at::Tensor & p_out) {
+        CHECK_INPUT(p_in);
+        CHECK_INPUT(p_out);
+        int64_t num_elem = p_in.numel();
+        AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal");
+        unpack_e5m2_cuda(p_in, p_out);
+}
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
         m.def("adam", &adam, "Adam optimized CUDA implementation.");
         m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
+        m.def("unpack_e5m2", &unpack_e5m2, "Unpack byte tensor containing e5m2 floats.");
         m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
         m.def("adam_undo_mt", &fused_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
 }
apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu  (view file @ cd206434)

@@ -21,6 +21,82 @@ typedef enum{
   ADAM_MODE_1 = 1 // eps outside square root
 } adamMode_t;

+template <typename FROM_T, typename TO_T>
+__device__ void convert(const FROM_T vi, TO_T& vo)
+{
+    vo = static_cast<TO_T>(vi);
+}
+
+template <>
+__device__ void convert(const float vi, uint8_t& vo)
+{
+    union S
+    {
+        float as_float;
+        int as_int;
+    };
+    S s;
+    s.as_float = vi;
+    s.as_int = s.as_int & 0xFF800000;
+    union T
+    {
+        at::Half as_half;
+        uint8_t as_byte[2];
+    };
+    T t;
+    t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
+    vo = t.as_byte[1];
+}
+
+template <>
+__device__ void convert(const uint8_t vi, float& vo)
+{
+    union T
+    {
+        at::Half as_half;
+        uint8_t as_byte[2];
+    };
+    T t;
+    t.as_byte[0] = 0;
+    t.as_byte[1] = vi;
+    vo = static_cast<float>(t.as_half);
+}
+
+template <>
+__device__ void convert(const at::Half vi, uint8_t& vo)
+{
+    union S
+    {
+        float as_float;
+        int as_int;
+    };
+    S s;
+    s.as_float = static_cast<float>(vi);
+    s.as_int = s.as_int & 0xFF800000;
+    union T
+    {
+        at::Half as_half;
+        uint8_t as_byte[2];
+    };
+    T t;
+    t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
+    vo = t.as_byte[1];
+}
+
+template <>
+__device__ void convert(const uint8_t vi, at::Half& vo)
+{
+    union T
+    {
+        at::Half as_half;
+        uint8_t as_byte[2];
+    };
+    T t;
+    t.as_byte[0] = 0;
+    t.as_byte[1] = vi;
+    vo = t.as_half;
+}
+
 template <typename GRAD_T>
 __global__ void strided_check_finite_cuda_kernel(
         volatile int* noop_gmem,
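The float-to-uint8_t and Half-to-uint8_t specializations implement rounding for the 2-bit mantissa: masking the source bits with 0xFF800000 keeps only sign and exponent (i.e. ±2^e), and adding 2^e/8, half the weight of the last kept mantissa bit, before the fp16 cast turns the subsequent truncation to the high byte into round-to-nearest. A rough pure-PyTorch mirror of that trick, for illustration only (the function and variable names are mine, not from the commit; little-endian host assumed):

import torch

def pack_e5m2_reference(x: torch.Tensor) -> torch.Tensor:
    # Rough CPU mirror of the convert<float, uint8_t> specialization above.
    x = x.float().contiguous()
    bits = x.view(torch.int32)
    sign_exp = ((bits >> 23) << 23).view(torch.float32)    # same effect as `bits & 0xFF800000`
    rounded = (x + sign_exp / 8.0).half()                  # add 2^(e-3): half the last kept mantissa bit
    return rounded.view(torch.uint8).reshape(-1, 2)[:, 1].clone()   # keep the fp16 high byte

print(pack_e5m2_reference(torch.tensor([1.3])))            # tensor([61], dtype=torch.uint8): 0x3D, i.e. 1.25 in e5m2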
@@ -51,10 +127,10 @@ __global__ void strided_check_finite_cuda_kernel(
     }
 }

-template <typename T, typename GRAD_T>
+template <typename T, typename GRAD_T, typename REDU_T>
 __global__ void adam_cuda_kernel(
         T* __restrict__ p,
-        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
+        REDU_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
         T* __restrict__ m,
         T* __restrict__ v,
         const GRAD_T* __restrict__ g,
@@ -122,7 +198,9 @@ __global__ void adam_cuda_kernel(
                 m[j] = mi[ii];
                 v[j] = vi[ii];
                 p[j] = pi[ii];
-                if (p_copy != NULL) p_copy[j] = static_cast<GRAD_T>(pi[ii]);
+                if (p_copy != NULL) {
+                    convert(pi[ii], p_copy[j]);
+                }
             }
         }
     }
@@ -130,11 +208,61 @@ __global__ void adam_cuda_kernel(
     if (p_copy != NULL) {
         __syncthreads();
         if (overflow) {
-            p_copy[0] = INFINITY;
+            convert(float(INFINITY), p_copy[0]);
         }
     }
 }

+template <typename GRAD_T>
+__global__ void unpack_e5m2_kernel(
+        const uint8_t* p_in,
+        GRAD_T* p_out,
+        const size_t tsize)
+{
+        //Assuming 2D grids and 2D blocks
+        const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
+        const int threadsPerBlock = blockDim.x * blockDim.y;
+        const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
+        const int i = (blockId * threadsPerBlock + threadIdInBlock);
+        const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
+
+        uint8_t pi[ILP];
+        GRAD_T po[ILP];
+
+        bool overflow = false;
+        for (int j_start = 0; j_start < tsize; j_start += totThreads*ILP) {
+#pragma unroll
+                for (int ii = 0; ii < ILP; ii++) {
+                        pi[ii] = 0;
+                        int j = j_start + i + totThreads*ii;
+                        if (j < tsize) {
+                                pi[ii] = p_in[j];
+                        }
+                }
+#pragma unroll
+                for (int ii = 0; ii < ILP; ii++) {
+                        convert(pi[ii], po[ii]);
+                        if (!isfinite(po[ii])) {
+                                overflow = true;
+                        }
+                }
+#pragma unroll
+                for (int ii = 0; ii < ILP; ii++) {
+                        int j = j_start + i + totThreads*ii;
+                        if (j < tsize) {
+                                p_out[j] = po[ii];
+                        }
+                }
+        }
+        if (overflow) {
+                p_out[0] = INFINITY;
+        }
+}
+
 template <typename T, typename GRAD_T>
 __global__ void adam_undo_cuda_kernel(
         T* __restrict__ p,
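unpack_e5m2_kernel decodes each byte into a GRAD_T value and, if any decoded value is non-finite, also writes INFINITY into p_out[0], so the existing overflow check keeps working on the unpacked tensor. A rough CPU reference of those semantics (my own helper, not part of the commit):

import torch

def unpack_e5m2_reference(packed: torch.Tensor) -> torch.Tensor:
    # Decode e5m2 bytes to fp16; flag overflow via element 0, like the kernel does.
    assert packed.dtype == torch.uint8
    pairs = torch.zeros(packed.numel(), 2, dtype=torch.uint8)
    pairs[:, 1] = packed.cpu()                 # byte becomes the fp16 high byte (little-endian host)
    out = pairs.view(torch.float16).reshape(-1).clone()
    if not torch.isfinite(out).all():          # mirrors the kernel's `overflow` flag
        out[0] = float('inf')
    return out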
@@ -453,13 +581,14 @@ void fused_adam_cuda(
         cudaStream_t stream = at::cuda::getCurrentCUDAStream();

         if (g.scalar_type() == at::ScalarType::Half) {
             //all other values should be fp32 for half gradients
             AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
             //dispatch is done on the gradient type
             using namespace at; // prevents "toString is undefined" errors
+            if (p_copy.numel() == 0 || p_copy.scalar_type() == g.scalar_type()) {
             DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
                 using accscalar_t = at::acc_type<scalar_t_0, true>;
-                adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+                adam_cuda_kernel<accscalar_t, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
                         p.DATA_PTR<accscalar_t>(),
                         p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
                         m.DATA_PTR<accscalar_t>(),
@@ -474,10 +603,30 @@ void fused_adam_cuda(
                         (adamMode_t) mode,
                         decay);
                 );
+            } else {
+                AT_ASSERTM(p_copy.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
+                DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_e5m2_kernel",
+                    using accscalar_t = at::acc_type<scalar_t_0, true>;
+                    adam_cuda_kernel<accscalar_t, scalar_t_0, uint8_t><<<blocks,threadsPerBlock, 0, stream>>>(
+                            p.DATA_PTR<accscalar_t>(),
+                            p_copy.DATA_PTR<uint8_t>(),
+                            m.DATA_PTR<accscalar_t>(),
+                            v.DATA_PTR<accscalar_t>(),
+                            g.DATA_PTR<scalar_t_0>(),
+                            beta1,
+                            beta2,
+                            eps,
+                            grad_scale,
+                            step_size,
+                            tsize,
+                            (adamMode_t) mode,
+                            decay);
+                    );
+            }
         } else {
             using namespace at;
             DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
-                adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+                adam_cuda_kernel<scalar_t_0, scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
                         p.DATA_PTR<scalar_t_0>(),
                         NULL, //don't output p_copy for fp32, it's wasted write
                         m.DATA_PTR<scalar_t_0>(),
@@ -494,7 +643,30 @@ void fused_adam_cuda(
             );
         }
         THCudaCheck(cudaGetLastError());
 }

+void unpack_e5m2_cuda(at::Tensor & p_in, at::Tensor & p_out)
+{
+        //Get tensor size
+        int tsize = p_in.numel();
+        AT_ASSERTM(tsize == p_out.numel(), "p_in.numel() must equal p_out.numel()");
+        //Determine #threads and #blocks
+        const int threadsPerBlock = 512;
+        const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
+        AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32");
+        //Constants
+        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+        AT_ASSERTM(p_in.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
+        AT_ASSERTM(p_out.scalar_type() == at::ScalarType::Half, "expected parameter to be of half type");
+        DISPATCH_FLOAT_AND_HALF(p_out.scalar_type(), 0, "unpack_e5m2",
+                unpack_e5m2_kernel<scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
+                        p_in.DATA_PTR<uint8_t>(),
+                        p_out.DATA_PTR<scalar_t_0>(),
+                        tsize);
+                );
+        THCudaCheck(cudaGetLastError());
+}
+
 void fused_adam_undo_cuda(
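Taken together, the .cu changes expose an unpack_e5m2 entry point that takes a contiguous CUDA Byte tensor and fills a Half tensor with the same number of elements. A usage sketch, assuming the apex.contrib extension has been built and is importable as fused_adam_cuda:

import torch
import fused_adam_cuda   # apex.contrib extension (assumed built and importable under this name)

packed = torch.randint(0, 256, (1024,), dtype=torch.uint8, device='cuda')
unpacked = torch.empty_like(packed, dtype=torch.float16)
fused_adam_cuda.unpack_e5m2(packed, unpacked)   # fills `unpacked` with one fp16 per input byte

# Random bytes can encode inf/nan; the kernel then also writes inf into unpacked[0],
# which is how the optimizer's overflow check sees it.
print(unpacked[:4])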
apex/contrib/optimizers/distributed_fused_adam.py  (view file @ cd206434)

@@ -46,7 +46,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                  compute_L2_grad_norm=False, distributed_weight_update=0,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_rs_pg=1, dwu_num_ar_pg=4,
                  dwu_num_ag_pg=0, dwu_num_blk_st=1, revert_method=1, flat_mt=False,
-                 dwu_num_chunks=4, predivide=True, internal_pipeline=False):
+                 dwu_num_chunks=4, predivide=True, internal_pipeline=False,
+                 e5m2_allgather=False):
        global fused_adam_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")
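With the new keyword, turning on the packed allgather from user code is a one-line change. A hedged sketch (the model, learning rate, and all other constructor arguments are illustrative placeholders left at their defaults; torch.distributed is assumed to be initialized as this optimizer requires):

import torch
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam

model = torch.nn.Linear(1024, 1024).cuda()     # placeholder model on the local GPU
optimizer = DistributedFusedAdam(model.parameters(),
                                 lr=1e-3,
                                 e5m2_allgather=True)   # option added by this commit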
@@ -80,6 +81,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        self._num_chunks = dwu_num_chunks
        self._predivide = predivide
        self._internal_pipeline = internal_pipeline
+       self._e5m2_allgather = e5m2_allgather
        self._full_pipeline = full_pipeline
        self._compute_L2_grad_norm = compute_L2_grad_norm
        self._L2_grad_norm = torch.zeros([]).cuda() if self._compute_L2_grad_norm else None
@@ -306,6 +308,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
            with torch.cuda.stream(self._blk_st[block_id%len(self._blk_st)]):
                if self._full_pipeline:
                    if self._new_params is None:
-                       self._new_params = torch.zeros_like(self._flat_grads)
+                       if self._e5m2_allgather:
+                           self._new_params = torch.zeros_like(self._flat_grads, dtype=torch.uint8)
+                       else:
+                           self._new_params = torch.zeros_like(self._flat_grads)
                    self._pipeline_block(block_id, self._flat_grads, self._new_params)
                else:
@@ -539,6 +544,9 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        if self._last_step or not self._overlap_reductions or not self._full_pipeline:
            if self._new_params is None:
-               self._new_params = torch.zeros_like(self._flat_grads)
+               if self._e5m2_allgather:
+                   self._new_params = torch.zeros_like(self._flat_grads, dtype=torch.uint8)
+               else:
+                   self._new_params = torch.zeros_like(self._flat_grads)
            for inv_block_id in range(self._num_blocks):
                block_id = self._num_blocks - inv_block_id - 1
@@ -551,7 +559,12 @@ class DistributedFusedAdam(torch.optim.Optimizer):
        # Check for overflow
        # Store state for loss scaler calculation
-       self.strided_check_finite(self._new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
+       if self._e5m2_allgather:
+           new_params = torch.empty_like(self._flat_grads)
+           fused_adam_cuda.unpack_e5m2(self._new_params, new_params)
+       else:
+           new_params = self._new_params
+       self.strided_check_finite(new_params, stride=self._shard_size, start=0, end=self._net_total_param_size)
        if self.peek_overflow:
            print("Reverting step")
            self.revert_step()
@@ -569,7 +582,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                state['step'] += 1
                nels = p.numel()
                offset = self._grads_info[param_i]['param_offset']
-               p.set_(self._new_params[offset:offset+nels].view_as(p))
+               p.set_(new_params[offset:offset+nels].view_as(p))
                param_i += 1
        self._new_params = None
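The net effect of the Python changes: with e5m2_allgather=True the allgathered buffer holds one byte per parameter instead of an fp16, and it is decoded exactly once via fused_adam_cuda.unpack_e5m2 before the finite check and before parameters are updated from slices of it. A rough standalone sketch of that data path, under the assumptions that a process group is initialized and the extension is built; tensor names and sizes are illustrative, not from the commit:

import torch
import torch.distributed as dist
import fused_adam_cuda                      # apex.contrib extension (assumed built)

shard_numel = 1 << 20                       # illustrative per-rank shard size
# Bytes produced by the fused Adam kernel when its p_copy buffer is torch.uint8.
shard = torch.empty(shard_numel, dtype=torch.uint8, device='cuda')

# Allgather one byte per parameter (half the traffic of an fp16 allgather).
gathered = [torch.empty_like(shard) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, shard)

# Decode to fp16 once, after communication, then check finiteness and apply.
packed = torch.cat(gathered)
new_params = torch.empty(packed.numel(), dtype=torch.float16, device='cuda')
fused_adam_cuda.unpack_e5m2(packed, new_params)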