OpenDAS / apex · Commits · 4a01ff26

Commit 4a01ff26, authored Apr 16, 2020 by Thor Johnsen
Parent: 2622d7f1

    Partial move towards syncfree optimizer

Showing 3 changed files, with 224 additions and 178 deletions:

  - apex/contrib/csrc/optimizers/fused_adam_cuda.cpp (+19, -16)
  - apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu (+197, -160)
  - apex/contrib/optimizers/distributed_fused_adam.py (+8, -2)
File: apex/contrib/csrc/optimizers/fused_adam_cuda.cpp (+19, -16) @ 4a01ff26
 #include <torch/extension.h>
 // CUDA forward declaration
-void fused_strided_check_finite(at::Tensor & noop, at::Tensor & p_copy, int stride, int clear_overflow_first);
+void fused_strided_check_finite(at::Tensor & overflow_flag, at::Tensor & p_copy, int stride, int clear_overflow_first);
 void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_undo_cuda(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_maybe_adam_undo_cuda(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_adam_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void fused_adam_undo_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
+void fused_maybe_adam_undo_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-void unpack_e5m2_cuda(at::Tensor & p_in, at::Tensor & p_out);
+void maybe_cast_cuda(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out);
-void unpack_e5m2_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
+void maybe_cast_cuda_mt(int chunk_size, at::Tensor overflow_flag, std::vector<std::vector<at::Tensor>> tensor_lists);
+void update_step_and_loss_scaler_cuda(at::Tensor & overflow_flag, at::Tensor & step_and_loss_scaler);
 #define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
 #define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
...
@@ -18,13 +20,13 @@ void unpack_e5m2_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::
 // C++ interface
 void strided_check_finite(
-        at::Tensor & noop,
+        at::Tensor & overflow_flag,
         at::Tensor & p_copy,
         int stride,
         int clear_overflow_first)
 {
     CHECK_INPUT(p_copy);
-    fused_strided_check_finite(noop, p_copy, stride, clear_overflow_first);
+    fused_strided_check_finite(overflow_flag, p_copy, stride, clear_overflow_first);
 }
 void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
     CHECK_INPUT(p);
...
@@ -40,7 +42,7 @@ void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, a
     fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
 }
-void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
+void maybe_adam_undo(at::Tensor & overflow_flag, at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
     CHECK_INPUT(p);
     CHECK_INPUT(m);
     CHECK_INPUT(v);
...
@@ -50,23 +52,24 @@ void adam_undo(at::Tensor & p, at::Tensor & m, at::Tensor & v, at::Tensor & g, f
     AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
     AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
-    fused_adam_undo_cuda(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
+    fused_maybe_adam_undo_cuda(overflow_flag, p, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
 }
-void unpack_e5m2(at::Tensor & p_in, at::Tensor & p_out) {
+void maybe_cast(at::Tensor & overflow_flag, at::Tensor & p_in, at::Tensor & p_out) {
     CHECK_INPUT(p_in);
     CHECK_INPUT(p_out);
     int64_t num_elem = p_in.numel();
     AT_ASSERTM(p_out.numel() == num_elem, "number of elements in p_in and p_out should be equal");
-    unpack_e5m2_cuda(p_in, p_out);
+    maybe_cast_cuda(overflow_flag, p_in, p_out);
 }
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("strided_check_finite", &strided_check_finite, "Strided finite check.");
     m.def("adam", &adam, "Adam optimized CUDA implementation.");
-    m.def("adam_undo", &adam_undo, "Undo function for Adam optimized CUDA implementation.");
     m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
-    m.def("adam_undo_mt", &fused_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
-    m.def("unpack_e5m2", &unpack_e5m2, "Unpack byte tensor containing e5m2 floats.");
-    m.def("unpack_e5m2_mt", &unpack_e5m2_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
+    m.def("maybe_adam_undo", &maybe_adam_undo, "Undo function for Adam optimized CUDA implementation.");
+    m.def("maybe_adam_undo_mt", &fused_maybe_adam_undo_cuda_mt, "Multi tensor undo function for Adam optimized CUDA implementation.");
+    m.def("maybe_cast", &maybe_cast, "Unpack byte tensor containing e5m2 floats.");
+    m.def("maybe_cast_mt", &maybe_cast_cuda_mt, "Unpack byte tensor containing e5m2 floats.");
+    m.def("update_step_and_loss_scaler", &update_step_and_loss_scaler_cuda, "Update step and loss scaler.");
 }
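Taken together, the C++ changes above rename the host-facing entry points into "maybe_" variants that consult a GPU-resident overflow flag instead of a host-visible noop flag. The sketch below is a minimal, hypothetical illustration of how the rebuilt fused_adam_cuda extension could be driven from Python after this commit. It assumes the apex contrib extension is built and a CUDA device is available; tensor sizes and the initial loss-scaler values are illustrative only.

    import torch
    import fused_adam_cuda  # apex contrib extension built from fused_adam_cuda.cpp (assumed available)

    device = torch.device("cuda")

    # GPU-resident overflow flag: non-zero means the last scaled step overflowed.
    overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)

    # A half-precision parameter copy and a full-precision destination (sizes illustrative).
    p_half = torch.randn(1024, device=device, dtype=torch.float16)
    p_out = torch.empty(1024, device=device, dtype=torch.float32)

    # Set the flag if any element of p_half is non-finite; arguments follow the binding
    # signature (flag, p_copy, stride, clear_overflow_first).
    fused_adam_cuda.strided_check_finite(overflow_buf, p_half, 1, 1)

    # Cast p_half into p_out unless the overflow flag is already set; the kernel
    # early-exits on the device, so no host sync is needed to make that decision.
    fused_adam_cuda.maybe_cast(overflow_buf, p_half, p_out)

    # Device-side dynamic loss-scaler bookkeeping: a 6-element double tensor holding
    # [step, iter, loss_scale, last_overflow_iter, scale_factor, scale_window].
    state = torch.tensor([0.0, 0.0, 2.0**15, -1.0, 2.0, 1000.0],
                         dtype=torch.float64, device=device)
    fused_adam_cuda.update_step_and_loss_scaler(overflow_buf, state)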
File: apex/contrib/csrc/optimizers/fused_adam_cuda_kernel.cu (+197, -160) @ 4a01ff26
...
@@ -157,6 +157,51 @@ __global__ void strided_check_finite_cuda_kernel(
     }
 }
+
+template <typename FROM_T, typename TO_T>
+__global__ void maybe_cast_kernel(
+        volatile int* overflow_flag,
+        const FROM_T* p_in,
+        TO_T* p_out,
+        const size_t tsize)
+{
+    if (overflow_flag && *overflow_flag != 0) return;
+
+    //Assuming 2D grids and 2D blocks
+    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
+    const int threadsPerBlock = blockDim.x * blockDim.y;
+    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
+    const int i = (blockId * threadsPerBlock + threadIdInBlock);
+    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
+
+    FROM_T pi[ILP];
+    TO_T po[ILP];
+
+    for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) {
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+            pi[ii] = 0;
+            int j = j_start + i + totThreads * ii;
+            if (j < tsize) {
+                pi[ii] = p_in[j];
+            }
+        }
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+            convert(pi[ii], po[ii]);
+        }
+#pragma unroll
+        for (int ii = 0; ii < ILP; ii++) {
+            int j = j_start + i + totThreads * ii;
+            if (j < tsize) {
+                p_out[j] = po[ii];
+            }
+        }
+    }
+}
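For readers who do not want to trace the ILP-unrolled indexing, here is a plain PyTorch sketch of what maybe_cast_kernel computes: if an overflow flag is supplied and already non-zero the cast is skipped entirely, otherwise every element of p_in is converted to the output type. The helper name maybe_cast_reference is ours, not part of apex, and the flag read below happens on the host, which is exactly the synchronization the CUDA kernel avoids by checking the flag on the device.

    import torch

    def maybe_cast_reference(overflow_flag, p_in, p_out):
        """Host-side reference for maybe_cast_kernel (illustrative only).

        overflow_flag: None or a 1-element int tensor; non-zero means "skip".
        p_in, p_out:   same number of elements; p_out's dtype is the target type.
        """
        if overflow_flag is not None and int(overflow_flag.item()) != 0:
            return  # kernel early-exits: the cast becomes a no-op after an overflow
        # convert(pi[ii], po[ii]) for every element, i.e. an elementwise dtype cast
        p_out.copy_(p_in.to(p_out.dtype))

    # Example
    flag = torch.zeros(1, dtype=torch.int32)
    src = torch.randn(8, dtype=torch.float32)
    dst = torch.empty(8, dtype=torch.float16)
    maybe_cast_reference(flag, src, dst)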
 template <typename T, typename GRAD_T, typename REDU_T>
 __global__ void adam_cuda_kernel(
         T* __restrict__ p,
...
@@ -243,58 +288,9 @@ __global__ void adam_cuda_kernel(
     }
 }
-template <typename GRAD_T>
-__global__ void unpack_e5m2_kernel(
-        const uint8_t* p_in,
-        GRAD_T* p_out,
-        const size_t tsize)
-{
-    //Assuming 2D grids and 2D blocks
-    const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
-    const int threadsPerBlock = blockDim.x * blockDim.y;
-    const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
-    const int i = (blockId * threadsPerBlock + threadIdInBlock);
-    const int totThreads = gridDim.x * gridDim.y * threadsPerBlock;
-
-    uint8_t pi[ILP];
-    GRAD_T po[ILP];
-    bool overflow = false;
-    for (int j_start = 0; j_start < tsize; j_start += totThreads * ILP) {
-#pragma unroll
-        for (int ii = 0; ii < ILP; ii++) {
-            pi[ii] = 0;
-            int j = j_start + i + totThreads * ii;
-            if (j < tsize) {
-                pi[ii] = p_in[j];
-            }
-        }
-#pragma unroll
-        for (int ii = 0; ii < ILP; ii++) {
-            convert(pi[ii], po[ii]);
-            if (!isfinite(po[ii])) {
-                overflow = true;
-            }
-        }
-#pragma unroll
-        for (int ii = 0; ii < ILP; ii++) {
-            int j = j_start + i + totThreads * ii;
-            if (j < tsize) {
-                p_out[j] = po[ii];
-            }
-        }
-    }
-    if (overflow) {
-        p_out[0] = INFINITY;
-    }
-}
 template <typename T, typename GRAD_T>
-__global__ void adam_undo_cuda_kernel(
+__global__ void maybe_adam_undo_cuda_kernel(
+        volatile int* overflow_flag,
         T* __restrict__ p,
         T* __restrict__ m,
         T* __restrict__ v,
...
@@ -308,6 +304,9 @@ __global__ void adam_undo_cuda_kernel(
         adamMode_t mode,
         const float decay)
 {
+    // NB! Skip undo kernel when overflow flag is NOT set
+    if (overflow_flag && *overflow_flag == 0) return;
+
     //Assuming 2D grids and 2D blocks
     const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
     const int threadsPerBlock = blockDim.x * blockDim.y;
...
@@ -367,15 +366,46 @@ __global__ void adam_undo_cuda_kernel(
     }
 }
+
+__global__ void update_step_and_loss_scaler_kernel(
+        volatile int* overflow_flag,
+        double* __restrict__ step_and_loss_scaler_vec)
+{
+    // 0 : step
+    // 1 : iter
+    // 2 : loss_scale
+    // 3 : last_overflow_iter
+    // 4 : scale_factor
+    // 5 : scale_window
+    if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
+        double loss_scale = step_and_loss_scaler_vec[2];
+        double scale_factor = step_and_loss_scaler_vec[4];
+        int iter = static_cast<int>(step_and_loss_scaler_vec[1]);
+        int last_overflow_iter = static_cast<int>(step_and_loss_scaler_vec[3]);
+        if (*overflow_flag == 0) {
+            // increase step
+            step_and_loss_scaler_vec[0] += 1.0;
+            // maybe increase loss scaler
+            int scale_window = static_cast<int>(step_and_loss_scaler_vec[5]);
+            if (((iter - last_overflow_iter) % scale_window) == 0) {
+                step_and_loss_scaler_vec[2] = loss_scale * scale_factor;
+            }
+        } else {
+            step_and_loss_scaler_vec[2] = loss_scale / scale_factor;
+            step_and_loss_scaler_vec[3] = static_cast<double>(iter);
+        }
+        step_and_loss_scaler_vec[1] += 1.0;
+    }
+}
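The single-thread kernel above moves the dynamic loss-scaler bookkeeping onto the GPU so that the overflow flag never has to be copied back to the host to decide how the scale changes. Its logic, restated as a Python reference sketch (the function name and the example starting values are ours; the six slots follow the comment in the kernel):

    def update_step_and_loss_scaler_reference(overflow, state):
        """state = [step, iter, loss_scale, last_overflow_iter, scale_factor, scale_window].
        Illustrative host-side restatement of update_step_and_loss_scaler_kernel."""
        step, it, loss_scale, last_overflow_iter, scale_factor, scale_window = state
        if not overflow:
            step += 1.0                                   # take the step
            if (int(it) - int(last_overflow_iter)) % int(scale_window) == 0:
                loss_scale = loss_scale * scale_factor    # grow the scale after a calm window
        else:
            loss_scale = loss_scale / scale_factor        # back off after an overflow
            last_overflow_iter = float(int(it))           # remember when it happened
        it += 1.0
        return [step, it, loss_scale, last_overflow_iter, scale_factor, scale_window]

    # One clean iteration followed by one overflow:
    s = [0.0, 0.0, 2.0**15, -1.0, 2.0, 1000.0]
    s = update_step_and_loss_scaler_reference(False, s)
    s = update_step_and_loss_scaler_reference(True, s)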
 template <int DEPTH, typename FROM_T, typename TO_T>
-struct UnpackE5M2Functor
+struct MaybeCastFunctor
 {
     __device__ __forceinline__ void operator()(
         int chunk_size,
-        volatile int* noop_gmem,
+        volatile int* overflow_flag,
         TensorListMetadata<DEPTH>& tl)
     {
-        if (*noop_gmem != 0) return;
+        if (overflow_flag && *overflow_flag != 0) return;
         int tensor_loc = tl.block_to_tensor[blockIdx.x];
         int chunk_idx = tl.block_to_chunk[blockIdx.x];
...
@@ -392,7 +422,6 @@ struct UnpackE5M2Functor
         FROM_T pi[ILP];
         TO_T po[ILP];
-        bool overflow = false;
         for (int j_start = 0; j_start < dim; j_start += blockDim.x * ILP) {
#pragma unroll
             for (int ii = 0; ii < ILP; ii++) {
...
@@ -406,9 +435,6 @@ struct UnpackE5M2Functor
#pragma unroll
             for (int ii = 0; ii < ILP; ii++) {
                 convert(pi[ii], po[ii]);
-                if (!isfinite(po[ii])) {
-                    overflow = true;
-                }
             }
#pragma unroll
...
@@ -419,10 +445,6 @@ struct UnpackE5M2Functor
                 }
             }
         }
-        if (overflow) {
-            *noop_gmem = 1;
-        }
     }
 };
...
@@ -431,7 +453,7 @@ struct AdamFunctor
 {
     __device__ __forceinline__ void operator()(
         int chunk_size,
-        volatile int* noop_gmem,
+        volatile int* overflow_flag,
         TensorListMetadata<DEPTH>& tl,
         const float b1,
         const float b2,
...
@@ -516,17 +538,17 @@ struct AdamFunctor
         }
         if (overflow) {
-            *noop_gmem = 1;
+            *overflow_flag = 1;
         }
     }
 };
 template <int DEPTH, typename T, typename GRAD_T>
-struct AdamUndoFunctor
+struct MaybeAdamUndoFunctor
 {
     __device__ __forceinline__ void operator()(
         int chunk_size,
-        volatile int* noop_gmem,
+        volatile int* overflow_flag,
         TensorListMetadata<DEPTH>& tl,
         const float b1,
         const float b2,
...
@@ -536,6 +558,9 @@ struct AdamUndoFunctor
         adamMode_t mode,
         const float decay)
     {
+        // Skip Adam undo when overflow flag is NOT set
+        if (overflow_flag && *overflow_flag == 0) return;
+
         int tensor_loc = tl.block_to_tensor[blockIdx.x];
         int chunk_idx = tl.block_to_chunk[blockIdx.x];
         int n = tl.sizes[tensor_loc];
...
@@ -606,7 +631,7 @@ struct AdamUndoFunctor
 };
 void fused_strided_check_finite(
-        at::Tensor & noop,
+        at::Tensor & overflow_flag,
         at::Tensor & p_copy,
         int stride,
         int clear_overflow_first)
...
@@ -624,7 +649,7 @@ void fused_strided_check_finite(
     using namespace at; // prevents "toString is undefined" errors
     DISPATCH_FLOAT_HALF_AND_BYTE(p_copy.scalar_type(), 0, "check_finite_cuda_kernel",
         strided_check_finite_cuda_kernel<scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
-            noop.DATA_PTR<int>(),
+            overflow_flag.DATA_PTR<int>(),
             p_copy.DATA_PTR<scalar_t_0>(),
             tsize,
             stride,
...
@@ -734,7 +759,8 @@ void fused_adam_cuda(
     THCudaCheck(cudaGetLastError());
 }
-void unpack_e5m2_cuda(
+void maybe_cast_cuda(
+        at::Tensor & overflow_flag,
         at::Tensor & p_in,
         at::Tensor & p_out)
 {
...
@@ -747,20 +773,19 @@ void unpack_e5m2_cuda(
     AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p_in), "parameter tensor is too large to be indexed with int32");
     //Constants
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-    AT_ASSERTM(p_in.scalar_type() == at::ScalarType::Byte, "expected parameter to be of byte type");
-    AT_ASSERTM(p_out.scalar_type() == at::ScalarType::Half, "expected parameter to be of half type");
-    DISPATCH_FLOAT_AND_HALF(p_out.scalar_type(), 0, "unpack_e5m2",
-        unpack_e5m2_kernel<scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
-            p_in.DATA_PTR<uint8_t>(),
-            p_out.DATA_PTR<scalar_t_0>(),
-            tsize);
-        );
+    DISPATCH_FLOAT_HALF_AND_BYTE(p_in.scalar_type(), 0, "maybe_cast_cuda"
+        DISPATCH_FLOAT_HALF_AND_BYTE(p_out.scalar_type(), 1, "maybe_cast_cuda",
+            maybe_cast_kernel<scalar_t_0, scalar_t_1><<<blocks, threadsPerBlock, 0, stream>>>(
+                overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
+                p_in.DATA_PTR<scalar_t_0>(),
+                p_out.DATA_PTR<scalar_t_1>(),
+                tsize);
+        ))
     THCudaCheck(cudaGetLastError());
 }
-void unpack_e5m2_cuda_mt(
+void maybe_cast_cuda_mt(
     int chunk_size,
-    at::Tensor noop_flag,
+    at::Tensor overflow_flag,
     std::vector<std::vector<at::Tensor>> tensor_lists) // p_in, p_out
 {
     //Constants
...
@@ -769,18 +794,31 @@ void unpack_e5m2_cuda_mt(
     size_t tl_sz = tensor_lists.size();
     AT_ASSERTM(tl_sz == 2, "expected tensor lists of size 2");
-    DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 0, "unpack_e5m2_cuda_mt_kernel",
+    DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[0][0].scalar_type(), 0, "maybe_cast_cuda_mt_kernel",
+        DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[1][0].scalar_type(), 1, "maybe_cast_cuda_mt_kernel",
             multi_tensor_apply<2>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
-                UnpackE5M2Functor<2, uint8_t, scalar_t_0>());
+                MaybeCastFunctor<2, scalar_t_0, scalar_t_1>());
-    );
+    ))
     THCudaCheck(cudaGetLastError());
 }
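The _mt wrapper above is meant to be driven through apex's multi_tensor_applier, which chunks the tensor lists and runs MaybeCastFunctor over each chunk. A minimal sketch of that call path, mirroring how distributed_fused_adam.py uses it later in this commit; the tensor contents and sizes are illustrative and the fused_adam_cuda extension must be built:

    import torch
    import fused_adam_cuda
    from apex.multi_tensor_apply import multi_tensor_applier

    overflow_buf = torch.zeros(1, dtype=torch.int32, device="cuda")

    # Lists of per-parameter shards: tensor_lists = [p_in list, p_out list]
    p_in = [torch.randn(257, device="cuda", dtype=torch.float16) for _ in range(3)]
    p_out = [torch.empty(257, device="cuda", dtype=torch.float32) for _ in range(3)]

    # Dispatches MaybeCastFunctor<2, ...> over chunks of every tensor pair;
    # each chunk early-exits on the device if overflow_buf is already non-zero.
    multi_tensor_applier(fused_adam_cuda.maybe_cast_mt, overflow_buf, [p_in, p_out])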
-void fused_adam_undo_cuda(
+void update_step_and_loss_scaler_cuda(
+        at::Tensor & overflow_flag,
+        at::Tensor & step_and_loss_scaler)
+{
+    AT_ASSERTM(step_and_loss_scaler.numel() == 6, "step_and_loss_scaler must have 6 elements");
+    AT_ASSERTM(step_and_loss_scaler.scalar_type() == at::ScalarType::Double, "expected step_and_loss_scaler to be a double tensor");
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+    update_step_and_loss_scaler_kernel<<<1, 1, 0, stream>>>(
+        overflow_flag.DATA_PTR<int>(),
+        step_and_loss_scaler.DATA_PTR<double>());
+}
+
+void fused_maybe_adam_undo_cuda(
+        at::Tensor & overflow_flag,
         at::Tensor & p,
         at::Tensor & m,
         at::Tensor & v,
...
@@ -795,8 +833,6 @@ void fused_adam_undo_cuda(
         int bias_correction,
         float decay)
 {
-    // using namespace at;
     //Get tensor size
     int tsize = p.numel();
     //Determine #threads and #blocks
...
@@ -816,13 +852,14 @@ void fused_adam_undo_cuda(
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     if (g.scalar_type() == at::ScalarType::Half) {
         //all other values should be fp32 for half gradients
         AT_ASSERTM(p.scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
         //dispatch is done on the gradient type
         using namespace at; // prevents "toString is undefined" errors
         DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
             using accscalar_t = at::acc_type<scalar_t_0, true>;
-            adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
+            maybe_adam_undo_cuda_kernel<accscalar_t, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
+                overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
                 p.DATA_PTR<accscalar_t>(),
                 m.DATA_PTR<accscalar_t>(),
                 v.DATA_PTR<accscalar_t>(),
...
@@ -839,7 +876,8 @@ void fused_adam_undo_cuda(
     } else {
         using namespace at;
         DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
-            adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
+            maybe_adam_undo_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks, threadsPerBlock, 0, stream>>>(
+                overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL,
                 p.DATA_PTR<scalar_t_0>(),
                 m.DATA_PTR<scalar_t_0>(),
                 v.DATA_PTR<scalar_t_0>(),
...
@@ -855,12 +893,11 @@ void fused_adam_undo_cuda(
         );
     }
     THCudaCheck(cudaGetLastError());
 }
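Note the overflow_flag.numel() ? overflow_flag.DATA_PTR<int>() : NULL guard: passing an empty tensor hands the kernel a NULL flag pointer, and maybe_adam_undo_cuda_kernel then runs unconditionally, since its early-exit only triggers when a flag is present and reads zero. A hedged Python sketch of the two calling modes (the argument order follows the binding in fused_adam_cuda.cpp; the values are illustrative and the extension must be built):

    import torch
    import fused_adam_cuda

    p, m, v, g = (torch.zeros(1024, device="cuda") for _ in range(4))
    lr, beta1, beta2, eps, grad_scale = 1e-3, 0.9, 0.999, 1e-8, 65536.0
    step, mode, bias_correction, decay = 10, 0, 1, 0.0

    # 1) Unconditional undo: an empty flag tensor maps to a NULL pointer in the kernel.
    fused_adam_cuda.maybe_adam_undo(torch.empty([0]), p, m, v, g,
                                    lr, beta1, beta2, eps, grad_scale,
                                    step, mode, bias_correction, decay)

    # 2) Conditional undo: the step is reverted only if overflow_buf was set on the device.
    overflow_buf = torch.zeros(1, dtype=torch.int32, device="cuda")
    fused_adam_cuda.maybe_adam_undo(overflow_buf, p, m, v, g,
                                    lr, beta1, beta2, eps, grad_scale,
                                    step, mode, bias_correction, decay)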
 void fused_adam_cuda_mt(
     int chunk_size,
-    at::Tensor noop_flag,
+    at::Tensor overflow_flag,
     std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
     float lr,
     float beta1,
...
@@ -897,7 +934,7 @@ void fused_adam_cuda_mt(
             multi_tensor_apply<5>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
                 AdamFunctor<5, accscalar_t, scalar_t_0>(),
                 beta1,
...
@@ -914,7 +951,7 @@ void fused_adam_cuda_mt(
             multi_tensor_apply<4>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
                 AdamFunctor<4, accscalar_t, scalar_t_0>(),
                 beta1,
...
@@ -932,7 +969,7 @@ void fused_adam_cuda_mt(
             multi_tensor_apply<5>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
                 AdamFunctor<5, scalar_t_0, scalar_t_0>(),
                 beta1,
...
@@ -948,7 +985,7 @@ void fused_adam_cuda_mt(
             multi_tensor_apply<4>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
                 AdamFunctor<4, scalar_t_0, scalar_t_0>(),
                 beta1,
...
@@ -964,9 +1001,9 @@ void fused_adam_cuda_mt(
     THCudaCheck(cudaGetLastError());
 }
-void fused_adam_undo_cuda_mt(
+void fused_maybe_adam_undo_cuda_mt(
     int chunk_size,
-    at::Tensor noop_flag,
+    at::Tensor overflow_flag,
     std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
     float lr,
     float beta1,
...
@@ -997,14 +1034,14 @@ void fused_adam_undo_cuda_mt(
         //alher values should be fp32 for half gradients
         AT_ASSERTM(tensor_lists[0][0].scalar_type() == at::ScalarType::Float, "expected parameter to be of float type");
         //dich is done on the gradient type
-        DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "adam_undo_cuda_mt_kernel",
+        DISPATCH_FLOAT_AND_HALF(tensor_lists[3][0].scalar_type(), 0, "maybe_adam_undo_cuda_mt_kernel",
             using accscalar_t = at::acc_type<scalar_t_0, true>;
             multi_tensor_apply<4>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
-                AdamUndoFunctor<4, accscalar_t, scalar_t_0>(),
+                MaybeAdamUndoFunctor<4, accscalar_t, scalar_t_0>(),
                 beta1,
                 beta2,
                 eps,
...
@@ -1014,13 +1051,13 @@ void fused_adam_undo_cuda_mt(
                 decay);
         );
     } else {
-        DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "adam_undo_cuda_mt_kernel",
+        DISPATCH_DOUBLE_AND_FLOAT(tensor_lists[3][0].scalar_type(), 0, "maybe_adam_undo_cuda_mt_kernel",
             multi_tensor_apply<4>(
                 BLOCK_SIZE,
                 chunk_size,
-                noop_flag,
+                overflow_flag,
                 tensor_lists,
-                AdamUndoFunctor<4, scalar_t_0, scalar_t_0>(),
+                MaybeAdamUndoFunctor<4, scalar_t_0, scalar_t_0>(),
                 beta1,
                 beta2,
                 eps,
...
File: apex/contrib/optimizers/distributed_fused_adam.py (+8, -2) @ 4a01ff26
...
@@ -154,6 +154,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 if torch.distributed.get_rank() in ranks:
                     self._ar_pg.append(grp)
         self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
+        for ar_pg in self._ar_pg:
+            torch.distributed.all_reduce(self._overflow_buf, group=ar_pg)
         rs_ranks = []
         for group_i in range(self._num_groups):
             rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
...
@@ -166,6 +168,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     self._rs_pg.append(grp)
             if self._compute_L2_grad_norm and torch.distributed.get_rank() in ranks:
                 self._l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
+                torch.distributed.all_reduce(self._overflow_buf, group=self._l2_grad_norm_pg)
         self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
         if self._num_ag_pg == 0:
             self._ag_pg = self._rs_pg
...
@@ -180,6 +183,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                     if torch.distributed.get_rank() in ranks:
                         self._ag_pg.append(grp)
             self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
+            for ag_pg in self._ag_pg:
+                torch.distributed.all_reduce(self._overflow_buf, group=ag_pg)
         self._l2_grad_norm_st = torch.cuda.Stream() if self._compute_L2_grad_norm else None
         self._completion_st = torch.cuda.Stream()
...
@@ -452,7 +457,8 @@ class DistributedFusedAdam(torch.optim.Optimizer):
         beta1, beta2 = group['betas']
         if undo:
             if self._revert_method == 1:
-                fused_adam_cuda.adam_undo(
+                fused_adam_cuda.maybe_adam_undo(
+                                         torch.empty([0]),
                                          self._fp32_p[group_buffer_start:group_buffer_end],
                                          self._fp32_m[group_buffer_start:group_buffer_end],
                                          self._fp32_v[group_buffer_start:group_buffer_end],
...
@@ -576,7 +582,7 @@ class DistributedFusedAdam(torch.optim.Optimizer):
                 param_i += 1
         if self._e5m2_allgather:
             multi_tensor_applier(
                     fused_adam_cuda.unpack_e5m2_mt,
-                    fused_adam_cuda.unpack_e5m2_mt,
+                    fused_adam_cuda.maybe_cast_mt,
                     self._overflow_buf,
                     [p_in, p_out]);
         elif self._do_not_flatten_model:
...
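The Python-side changes wire the same overflow buffer through every process group the optimizer creates: an all_reduce of self._overflow_buf is issued on each group as it is set up, presumably to initialize the communicators eagerly and to let the flag be combined across ranks later without a host sync. A stripped-down sketch of that pattern follows; process-group setup is elided, and combine_overflow_flags, overflow_buf, and the group list are stand-ins for the optimizer's attributes rather than apex API:

    import torch
    import torch.distributed as dist

    def combine_overflow_flags(overflow_buf, process_groups):
        """All-reduce a device-resident overflow flag over each group.

        After this, every rank sees a non-zero value if any rank overflowed,
        without copying the flag back to the host. Illustrative sketch only.
        """
        for pg in process_groups:
            dist.all_reduce(overflow_buf, group=pg)

    # Usage inside an initialized torch.distributed job (requires init_process_group):
    # overflow_buf = torch.zeros(1, dtype=torch.int32, device="cuda")
    # combine_overflow_flags(overflow_buf, [dist.new_group(ranks=[0, 1])])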