OpenDAS / apex / Commits

Commit 19892f1d, authored May 30, 2020 by Thor Johnsen

Distributed LAMB optimizer

Parent: 5754fa7a

Showing 4 changed files with 1173 additions and 0 deletions (+1173, -0)
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp         +50   -0
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu   +493  -0
apex/contrib/optimizers/distributed_fused_lamb.py                  +628  -0
setup.py                                                           +2    -0
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp (new file, 0 → 100644)
#include <torch/extension.h>

void multi_tensor_lamb_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  const float global_grad_norm,
  const float max_grad_norm,
  at::optional<bool> use_nvlamb_python);

void multi_tensor_lamb_compute_update_term_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_beta1,
  at::Tensor per_tensor_beta2,
  at::Tensor per_tensor_beta3,
  at::Tensor per_tensor_bias_correction,
  const int step,
  at::Tensor per_tensor_epsilon,
  const int mode,
  at::Tensor per_tensor_decay,
  const float global_grad_norm,
  const float max_global_grad_norm);

void multi_tensor_lamb_update_weights_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_param_norm,
  at::Tensor per_tensor_update_norm,
  const float learning_rate,
  at::Tensor per_tensor_decay,
  bool use_nvlamb);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("multi_tensor_lamb_compute_update_term", &multi_tensor_lamb_compute_update_term_cuda,
        "Computes update term for LAMB optimizer");
  m.def("multi_tensor_lamb_update_weights", &multi_tensor_lamb_update_weights_cuda,
        "Applies update term for LAMB optimizer");
}
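The two functions registered above are launched from Python through Apex's multi_tensor_applier helper, which prepends the chunk size and the no-op flag to the argument list. A minimal sketch of such a call, assuming the extension is importable as fused_lamb_cuda (the extension the new sources are added to in the setup.py change below) and that tensor_lists follows the [g, p, m, v, u] layout the stage-1 kernel expects; all tensors and hyperparameter values here are illustrative placeholders.

# Hypothetical sketch: invoking the new binding via apex.multi_tensor_apply.
# Assumes the extension built by setup.py is importable as `fused_lamb_cuda`.
import torch
import fused_lamb_cuda
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.cuda.IntTensor([0])
g = [torch.randn(1024, device='cuda', dtype=torch.float16)]   # gradients
p = [torch.randn(1024, device='cuda')]                        # master params
m = [torch.zeros(1024, device='cuda')]                        # first moment
v = [torch.zeros(1024, device='cuda')]                        # second moment
u = [torch.zeros(1024, device='cuda')]                        # update term written by the kernel

beta1 = torch.tensor([0.9], device='cuda')
beta2 = torch.tensor([0.999], device='cuda')
beta3 = torch.tensor([0.1], device='cuda')        # 1 - beta1 when grad averaging is on
bias_correction = torch.tensor([1], dtype=torch.int, device='cuda')
epsilon = torch.tensor([1e-8], device='cuda')
decay = torch.tensor([0.01], device='cuda')

multi_tensor_applier(
    fused_lamb_cuda.multi_tensor_lamb_compute_update_term,
    overflow_buf,
    [g, p, m, v, u],        # tensor_lists; one entry per parameter fragment
    beta1, beta2, beta3, bias_correction,
    1,                      # step
    epsilon,
    1,                      # mode (decoupled weight decay)
    decay,
    1.0,                    # global_grad_norm
    1.0)                    # max_global_grad_norm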
apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu (new file, 0 → 100644)
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

template<typename FROM_T, typename TO_T>
__device__ void convert(const FROM_T vi, TO_T& vo)
{
  vo = static_cast<TO_T>(vi);
}

template<>
__device__ void convert(const float vi, uint8_t& vo)
{
  union S { float as_float; int as_int; };
  S s;
  s.as_float = vi;
  s.as_int = s.as_int & 0xFF800000;
  union T { at::Half as_half; uint8_t as_byte[2]; };
  T t;
  t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
  vo = t.as_byte[1];
}

template<>
__device__ void convert(const uint8_t vi, float& vo)
{
  union T { at::Half as_half; uint8_t as_byte[2]; };
  T t;
  t.as_byte[0] = 0;
  t.as_byte[1] = vi;
  vo = static_cast<float>(t.as_half);
}

template<>
__device__ void convert(const at::Half vi, uint8_t& vo)
{
  union S { float as_float; int as_int; };
  S s;
  s.as_float = static_cast<float>(vi);
  s.as_int = s.as_int & 0xFF800000;
  union T { at::Half as_half; uint8_t as_byte[2]; };
  T t;
  t.as_half = static_cast<at::Half>(vi + s.as_float / 8.0f);
  vo = t.as_byte[1];
}

template<>
__device__ void convert(const uint8_t vi, at::Half& vo)
{
  union T { at::Half as_half; uint8_t as_byte[2]; };
  T t;
  t.as_byte[0] = 0;
  t.as_byte[1] = vi;
  vo = t.as_half;
}
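The uint8_t specializations above keep only the high byte of the value's fp16 bit pattern, i.e. an e5m2-style 8-bit float (1 sign, 5 exponent, 2 mantissa bits), which is what the optional e5m2_allgather mode in the Python file trades bandwidth for. The masked 0xFF800000 term isolates the sign and exponent, and adding that value divided by 8 adds half an e5m2 ULP so the byte truncation rounds to nearest. A host-side NumPy restatement of the same idea, illustrative only and not part of the extension:

# Illustrative sketch of the e5m2-style packing used by the convert<>() specializations.
import numpy as np

def float_to_e5m2_byte(x):
    f = np.array(x, dtype=np.float32)
    # sign and exponent of x with the mantissa cleared (the 0xFF800000 mask in the kernel)
    sign_exp = (f.view(np.uint32) & np.uint32(0xFF800000)).view(np.float32)
    # adding sign_exp/8 adds half an e5m2 ULP, so the truncation below rounds to nearest
    rounded = (f + sign_exp / np.float32(8.0)).astype(np.float16)
    return (rounded.view(np.uint16) >> 8).astype(np.uint8)    # high byte of the fp16 pattern

def e5m2_byte_to_float(b):
    half_bits = (np.array(b, dtype=np.uint16) << 8).astype(np.uint16)  # low byte zeroed
    return half_bits.view(np.float16).astype(np.float32)

x = 0.337
packed = float_to_e5m2_byte(x)
print(packed, e5m2_byte_to_float(packed))   # 8-bit code and the value it decodes back to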
typedef enum{
  MOMENT_MODE_0   = 0, // L2 regularization mode
  MOMENT_MODE_1   = 1  // Decoupled weight decay mode
} adamMode_t;

template<typename T, typename GRAD_T, typename MATH_T>
struct DistOptLAMBStage1Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<5>& tl,
    const MATH_T* per_tensor_beta1,
    const MATH_T* per_tensor_beta2,
    const MATH_T* per_tensor_beta3,
    const int* per_tensor_bias_correction,
    const int step,
    const MATH_T* per_tensor_epsilon,
    adamMode_t mode,
    const MATH_T* per_tensor_decay,
    const MATH_T global_grad_norm,
    const MATH_T max_global_grad_norm)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    MATH_T clipped_global_grad_norm = global_grad_norm > max_global_grad_norm ? global_grad_norm / max_global_grad_norm : (MATH_T) 1.0;
    MATH_T beta1 = per_tensor_beta1[tensor_num];
    MATH_T beta2 = per_tensor_beta2[tensor_num];
    MATH_T beta3 = per_tensor_beta3[tensor_num];
    MATH_T beta1_correction, beta2_correction;
    if (per_tensor_bias_correction[tensor_num] == 1) {
      beta1_correction = 1 - pow(beta1, (MATH_T) step);
      beta2_correction = 1 - pow(beta2, (MATH_T) step);
    }
    else {
      beta1_correction = (MATH_T) 1.0;
      beta2_correction = (MATH_T) 1.0;
    }
    MATH_T epsilon = per_tensor_epsilon[tensor_num];
    MATH_T decay = per_tensor_decay[tensor_num];

    GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    T* m = (T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;

    T* v = (T*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;

    MATH_T* u = (MATH_T*)tl.addresses[4][tensor_loc];
    u += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    MATH_T r_g[ILP];
    MATH_T r_p[ILP];
    MATH_T r_m[ILP];
    MATH_T r_v[ILP];
    // to make things simple, we put aligned case in a different code path
    if(n % ILP == 0 &&
       chunk_size % ILP == 0 &&
       is_aligned(g) &&
       is_aligned(p) &&
       is_aligned(m) &&
       is_aligned(v))
    {
      GRAD_T l_g[ILP];
      T l_p[ILP];
      T l_m[ILP];
      T l_v[ILP];
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(l_g, g, 0, i_start);
        if (decay != 0)
          load_store(l_p, p, 0, i_start);
        load_store(l_m, m, 0, i_start);
        load_store(l_v, v, 0, i_start);
        // unpack
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_g[ii] = l_g[ii];
          if (decay == 0) {
            r_p[ii] = MATH_T(0);
          }
          else {
            r_p[ii] = l_p[ii];
          }
          r_m[ii] = l_m[ii];
          r_v[ii] = l_v[ii];
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if (mode == MOMENT_MODE_0) {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            // L2 on scaled grad
            scaled_grad = scaled_grad + decay*r_p[ii];
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = next_m_unbiased / denom;
          }
          else {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          l_m[ii] = r_m[ii];
          l_v[ii] = r_v[ii];
        }
        // store
        load_store(u, r_p, i_start, 0);
        load_store(m, l_m, i_start, 0);
        load_store(v, l_v, i_start, 0);
      }
    }
    else
    {
      // see note in multi_tensor_scale_kernel.cu
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
        MATH_T r_g[ILP];
        MATH_T r_p[ILP];
        MATH_T r_m[ILP];
        MATH_T r_v[ILP];
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_g[ii] = g[i];
            // special ?optimization? for lamb stage 1
            if (decay == 0) {
              r_p[ii] = MATH_T(0);
            }
            else {
              r_p[ii] = p[i];
            }
            r_m[ii] = m[i];
            r_v[ii] = v[i];
          } else {
            r_g[ii] = MATH_T(0);
            r_p[ii] = MATH_T(0);
            r_m[ii] = MATH_T(0);
            r_v[ii] = MATH_T(0);
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if (mode == MOMENT_MODE_0) {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            // L2 on scaled grad
            scaled_grad = scaled_grad + decay*r_p[ii];
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = next_m_unbiased / denom;
          }
          else {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            u[i] = r_p[ii];
            m[i] = r_m[ii];
            v[i] = r_v[ii];
          }
        }
      }
    }
  }
};
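For reference, the per-element arithmetic that Stage 1 performs can be restated directly in NumPy; the sketch below mirrors the decoupled-weight-decay branch (MOMENT_MODE_1) of the kernel above and is purely illustrative, not code that ships with the commit.

# NumPy restatement of the Stage-1 update-term math (decoupled weight decay branch).
import numpy as np

def lamb_stage1(g, p, m, v, *, beta1, beta2, beta3, eps, decay, step,
                bias_correction=True, global_grad_norm=1.0, max_global_grad_norm=1.0):
    clip = global_grad_norm / max_global_grad_norm if global_grad_norm > max_global_grad_norm else 1.0
    scaled_grad = g / clip
    m = m * beta1 + beta3 * scaled_grad                          # first moment
    v = v * beta2 + (1.0 - beta2) * scaled_grad * scaled_grad    # second moment
    c1 = 1.0 - beta1 ** step if bias_correction else 1.0
    c2 = 1.0 - beta2 ** step if bias_correction else 1.0
    update = (m / c1) / (np.sqrt(v / c2) + eps) + decay * p      # the 'u' buffer the kernel writes
    return update, m, v

u, m, v = lamb_stage1(np.ones(4, np.float32), np.ones(4, np.float32),
                      np.zeros(4, np.float32), np.zeros(4, np.float32),
                      beta1=0.9, beta2=0.999, beta3=0.1, eps=1e-8, decay=0.01, step=1)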
// Step 2 reads in 'update' value and per-tensor param_norm and update_norm.
// It computes the new parameter value.
template<typename T, typename GRAD_T, typename MATH_T>
struct DistOptLAMBStage2Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<3>& tl,
    const MATH_T* per_tensor_param_norm,
    const MATH_T* per_tensor_update_norm,
    const MATH_T learning_rate,
    const MATH_T* per_tensor_decay,
    bool use_nvlamb)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    MATH_T decay = per_tensor_decay[tensor_num];

    MATH_T ratio = learning_rate;
    // nvlamb: apply adaptive learning rate to all parameters
    // otherwise, only apply to those with non-zero weight decay
    if (use_nvlamb || (decay != (MATH_T) 0.0))
    {
      MATH_T param_norm = per_tensor_param_norm[tensor_num];
      MATH_T update_norm = per_tensor_update_norm[tensor_num];
      ratio = (update_norm != (MATH_T) 0.0 && param_norm != (MATH_T) 0.0) ? learning_rate * (param_norm / update_norm) : learning_rate;
    }

    MATH_T* update = (MATH_T*)tl.addresses[0][tensor_loc];
    update += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    GRAD_T* p_copy = (GRAD_T*)tl.addresses[2][tensor_loc];
    p_copy += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // to make things simple, we put aligned case in a different code path
    if(n % ILP == 0 &&
       chunk_size % ILP == 0 &&
       is_aligned(p) &&
       is_aligned(update))
    {
      T r_p[ILP];
      MATH_T r_update[ILP];
      GRAD_T r_p_copy[ILP];
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_p, p, 0, i_start);
        load_store(r_update, update, 0, i_start);
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * r_update[ii]);
          convert(r_p[ii], r_p_copy[ii]);
        }
        load_store(p, r_p, i_start, 0);
        load_store(p_copy, r_p_copy, i_start, 0);
      }
    }
    else
    {
      for(int i_start = 0; i_start < n && i_start < chunk_size; i_start += blockDim.x*ILP)
      {
        MATH_T r_p[ILP];
        MATH_T r_update[ILP];
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_p[ii] = p[i];
            r_update[ii] = update[i];
          }
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
        }
#pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            p[i] = r_p[ii];
            convert(r_p[ii], p_copy[i]);
          }
        }
      }
    }
  }
};
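In scalar terms, Stage 2 rescales the learning rate by the per-tensor trust ratio ||p|| / ||u|| before taking the step. A NumPy restatement of that logic, illustrative only:

# NumPy restatement of the Stage-2 weight update with the LAMB trust ratio.
import numpy as np

def lamb_stage2(p, u, *, lr, decay, param_norm, update_norm, use_nvlamb=False):
    ratio = lr
    # nvlamb applies the adaptive rate to every tensor; otherwise only when decay != 0
    if use_nvlamb or decay != 0.0:
        if param_norm != 0.0 and update_norm != 0.0:
            ratio = lr * (param_norm / update_norm)
    return p - ratio * u

p = np.ones(4, np.float32)
u = np.full(4, 0.5, np.float32)
p_new = lamb_stage2(p, u, lr=1e-3, decay=0.01,
                    param_norm=np.linalg.norm(p), update_norm=np.linalg.norm(u))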
void multi_tensor_lamb_compute_update_term_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_beta1,
  at::Tensor per_tensor_beta2,
  at::Tensor per_tensor_beta3,
  at::Tensor per_tensor_bias_correction,
  const int step,
  at::Tensor per_tensor_epsilon,
  const int mode,
  at::Tensor per_tensor_decay,
  const float global_grad_norm,
  const float max_global_grad_norm)
{
  using namespace at;

  DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, "lamb_stage_1",
    DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 1, "lamb_stage_1",
      DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
        multi_tensor_apply<5>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          DistOptLAMBStage1Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
          per_tensor_beta1.DATA_PTR<scalar_t_2>(),
          per_tensor_beta2.DATA_PTR<scalar_t_2>(),
          per_tensor_beta3.DATA_PTR<scalar_t_2>(),
          per_tensor_bias_correction.DATA_PTR<int>(),
          step,
          per_tensor_epsilon.DATA_PTR<scalar_t_2>(),
          (adamMode_t) mode,
          per_tensor_decay.DATA_PTR<scalar_t_2>(),
          (scalar_t_2) global_grad_norm,
          (scalar_t_2) max_global_grad_norm); )))

  AT_CUDA_CHECK(cudaGetLastError());
}

void multi_tensor_lamb_update_weights_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor per_tensor_param_norm,
  at::Tensor per_tensor_update_norm,
  const float learning_rate,
  at::Tensor per_tensor_decay,
  bool use_nvlamb)
{
  using namespace at;

  DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 0, "lamb_stage_2",
    DISPATCH_FLOAT_HALF_AND_BYTE(tensor_lists[2][0].scalar_type(), 1, "lamb_stage_2",
      DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 2, "lamb_stage_2",
        multi_tensor_apply<3>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          DistOptLAMBStage2Functor<scalar_t_0, scalar_t_1, scalar_t_2>(),
          per_tensor_param_norm.DATA_PTR<scalar_t_2>(),
          per_tensor_update_norm.DATA_PTR<scalar_t_2>(),
          (scalar_t_2) learning_rate,
          per_tensor_decay.DATA_PTR<scalar_t_2>(),
          use_nvlamb); )))

  AT_CUDA_CHECK(cudaGetLastError());
}
apex/contrib/optimizers/distributed_fused_lamb.py (new file, 0 → 100644)
import math
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

class DistributedFusedLAMB(torch.optim.Optimizer):

    """Implements LAMB algorithm.

    Currently GPU-only. Requires Apex to be installed via
    ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.

    This version of fused LAMB implements 2 fusions.

      * Fusion of the LAMB update's elementwise operations
      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.

    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::

        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
        ...
        opt.step()

    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp. If you wish to use :class:`FusedLAMB` with Amp,
    you may choose any ``opt_level``::

        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
        model, opt = amp.initialize(model, opt, opt_level="O0", "O1", or "O2")
        ...
        opt.step()

    In general, ``opt_level="O1"`` is recommended.

    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_
            NOT SUPPORTED now! (default: False)
        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay
            True for decoupled weight decay (also known as AdamW) (default: True)
        grad_averaging (bool, optional): whether to apply (1-beta2) to grad when
            calculating running averages of gradient. (default: True)
        set_grad_none (bool, optional): whether to set grad to None when zero_grad()
            method is called. (default: True)
        max_grad_norm (float, optional): value used to clip global grad norm
            (default: 1.0)
        use_nvlamb (boolean, optional): Apply adaptive learning rate to 0.0
            weight decay parameter (default: False)

    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """
    def __init__(self, params,
                 lr=1e-3, bias_correction=True, grad_averaging=True,
                 betas=(0.9, 0.999), eps=1e-8, eps_inside_sqrt=False,
                 weight_decay=0., max_grad_norm=0., amsgrad=False, use_mt=False,
                 amp_scale_adjustment=1.0, overlap_reductions=True,
                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
                 e5m2_allgather=False):
        global fused_adam_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")
        # FIXME: Import multi_tensor_lamb_* kernels instead

        self._amp_scale_adjustment = amp_scale_adjustment

        if use_mt:
            raise RuntimeError('DistributedFusedLAMB does not support use_mt.')
        if amsgrad:
            raise RuntimeError('DistributedFusedLAMB does not support the AMSGrad variant.')

        defaults = dict(lr=lr, bias_correction=bias_correction,
                        grad_averaging=grad_averaging,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        max_grad_norm=max_grad_norm)
        super(DistributedFusedLAMB, self).__init__(params, defaults)
        self.eps_mode = 0 if eps_inside_sqrt else 1

        self._overflow_buf = torch.cuda.IntTensor([0])
        self._has_overflow = False

        # The multi_tensor_lamb_* kernels added by this commit are compiled into the
        # fused_lamb_cuda extension (see the setup.py change below).
        fused_lamb_cuda = importlib.import_module("fused_lamb_cuda")
        self.multi_tensor_lamb_compute_update_term = fused_lamb_cuda.multi_tensor_lamb_compute_update_term
        self.multi_tensor_lamb_update_weights = fused_lamb_cuda.multi_tensor_lamb_update_weights
        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm

        # Skip buffer
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])
        self.multi_tensor_lamb = amp_C.multi_tensor_lamb

        self._use_nvlamb = False  # see ``use_nvlamb`` in the docstring; not exposed as a constructor argument yet
        self._last_step = False
        self._overlap_reductions = overlap_reductions
        self._global_scale = None
        self._num_blocks = dwu_num_blocks
        self._num_chunks = dwu_num_chunks
        self._e5m2_allgather = e5m2_allgather
        self._L2_grad_norm = None
        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
        self._world_size = torch.distributed.get_world_size()
        self._num_groups = self._world_size // self._group_size
        self._rank_in_group = torch.distributed.get_rank() % self._group_size

        p_offset = 0
        p_i = 0
        self._model_params = []
        self._grads_info = []
        self._grad_accs = []
        self._group_properties = []
        self._param_state = None
        self._param_group = None
        for group in self.param_groups:
            if self._param_group is None:
                self._param_group = group
            prev = None
            beta1, beta2 = group['betas']
            for p in group['params']:
                torch.distributed.broadcast(p, 0)
                if not p.requires_grad:
                    continue
                self._model_params.append(p)
                self._group_properties.append((
                    group['weight_decay'],
                    1 if group['bias_correction'] else 0,
                    beta1,
                    beta2,
                    1.0 - beta1 if group['grad_averaging'] else 1.0,
                    group['eps']
                    ))
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                if self._param_state is None:
                    self._param_state = state
                p_grads_size = p.numel()
                def wrapper(param, param_i, param_grads_size, param_offset):
                    param_tmp = param.expand_as(param)
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    def allreduce_hook(*unused):
                        self._do_overlapped_reduction(param_i, param_grads_size, param_offset, param)
                    grad_acc.register_hook(allreduce_hook)
                    self._grad_accs.append(grad_acc)
                self._grads_info.append({"param_grads_size": p_grads_size, "param_offset": p_offset})
                wrapper(p, p_i, p_grads_size, p_offset)
                p_offset += p_grads_size
                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
                # RNN is one example of consecutive parameters:
                # (weight_ih, weight_hh, bias_ih, bias_hh)
                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
                    p_offset = ((p_offset + 63) // 64) * 64
                prev = p
                p_i += 1
        self._grads_generated = [False]*len(self._grads_info)
        self._grads_fp16, self._grads_fp32 = [], []
        if self._overlap_reductions:
            self._current_block = self._num_blocks

        self._net_total_param_size = p_offset
        self._total_param_size = p_offset
        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
        self._block_size = self._total_param_size // self._num_blocks
        self._chunk_size = self._block_size // self._num_chunks
        self._shard_size = self._chunk_size // self._group_size
        print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size, dwu_min_page_size, self._block_size, self._chunk_size, self._shard_size))

        self._low_param_i = [0]*self._num_blocks
        for block_id in range(self._num_blocks-1, -1, -1):
            p_i = len(self._grads_info)-1
            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
                p_i -= 1
            self._low_param_i[block_id] = p_i
        print(self._low_param_i)

        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
        self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
        self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        # FIXME: Rethink fp16 label since it's either uint8 or fp16
        self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')

        self._individual_flat_grads = []
        for p_i, (grads_info, p) in enumerate(zip(self._grads_info, self._model_params)):
            self._individual_flat_grads.append(self._flat_grads[grads_info["param_offset"]:grads_info["param_offset"]+grads_info["param_grads_size"]].view_as(p))

        def _flat_split(p):
            def __blockify(p):
                return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
            def __chunkify(p):
                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
            def __shardify(p):
                return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
            list_of_blocks = __blockify(self._flat_grads)
            list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
            list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
            return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
        self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)

        def _full_packed_split(p):
            def __shardify(p):
                return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
            def __blockify(p):
                return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]
            def __chunkify(p):
                return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
            list_of_mega_shards = __shardify(p)
            list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
            list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)

        def _packed_split(p):
            def __packed_blockify(p):
                packed_block_size = self._num_chunks*self._shard_size
                return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
            def __packed_chunkify(p):
                # in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size
                return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
            list_of_blocks = __packed_blockify(p)
            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
            return list_of_blocks, list_of_list_of_chunks
        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
        self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
        self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)
        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
        self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)

        # This paragraph does two things:
        # 1) Copy model parameters into master buffer
        # 2) Create tensor lists for unpacking new parameter tensor after all-gather
        self._packed_flat_to_model_params_fp16 = []
        self._packed_flat_to_model_params_fp32 = []
        self._model_params_num = len(self._model_params)
        self._contrib_tensor_list = []
        self._contrib_min_param_i, self._contrib_max_param_i = -1, -1
        self._contrib_update_frag_for_norm = []
        self._contrib_model_param_for_norm_fp16 = []
        self._contrib_model_param_for_norm_fp32 = []
        self._contrib_model_param_for_norm_is_fp16 = []
        self._contrib_group_properties = []
        for shard_id in range(self._group_size):
            for block_id in range(self._num_blocks):
                for chunk_id in range(self._num_chunks):
                    flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size
                    flat_shard_end = flat_shard_start + self._shard_size
                    for param_i, (p, grads_info, group_props) in enumerate(zip(self._model_params, self._grads_info, self._group_properties)):
                        flat_grad_start = grads_info["param_offset"]
                        flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
                        clipped_start = (lambda a, b: a if a > b else b)(flat_grad_start, flat_shard_start)
                        clipped_end = (lambda a, b: a if a < b else b)(flat_grad_end, flat_shard_end)
                        if clipped_start < clipped_end:
                            grad_offset = clipped_start - flat_grad_start
                            grad_length = clipped_end - clipped_start
                            shard_offset = clipped_start - flat_shard_start
                            model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
                            new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
                            if model_param_fragment.dtype == torch.float16:
                                self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )
                            else:
                                self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )
                            if shard_id == self._rank_in_group:
                                # copy model parameters into master buffer
                                master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_u_fragment = self._fp32_u_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
                                master_param_fragment.copy_(model_param_fragment)
                                self._contrib_group_properties.append(group_props)
                                self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment))  # p, m, v, u, g, p_copy
                                self._contrib_update_frag_for_norm.append(opti_state_u_fragment)
                                if p.dtype == torch.float16:
                                    self._contrib_model_param_for_norm_fp16.append(p)
                                else:
                                    self._contrib_model_param_for_norm_fp32.append(p)
                                self._contrib_model_param_for_norm_is_fp16.append(True if p.dtype == torch.float16 else False)
                                if self._contrib_min_param_i < 0: self._contrib_min_param_i = param_i
                                self._contrib_max_param_i = param_i
        self._contrib_model_param_for_norm_num = len(self._contrib_model_param_for_norm_is_fp16)
        if len(self._contrib_model_param_for_norm_fp16) == 0: self._contrib_model_param_for_norm_fp16 = None
        if len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None
        # Build the fp32 mask from the Python list before the fp16 list is replaced by a tensor
        self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
        self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
        p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))
        self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]
        self._contrib_update_weights_tensor_list = [u, p, p_copy]

        math_type = self._fp32_u.dtype
        decay, bias_correction, beta1, beta2, beta3, epsilon = list(zip(*self._contrib_group_properties))
        self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')
        self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')
        self._contrib_beta3 = torch.tensor(beta3, dtype=math_type, device='cuda')
        self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')
        self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')
        self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')

        self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None
        self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None

        self._num_rs_pg = dwu_num_rs_pg
        self._num_ar_pg = dwu_num_ar_pg
        self._num_ag_pg = dwu_num_ag_pg
        if self._num_groups > 1:
            self._ar_pg = []
            for dev_i in range(self._group_size):
                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
                for i in range(self._num_ar_pg):
                    grp = torch.distributed.new_group(ranks=ranks)
                    if torch.distributed.get_rank() in ranks:
                        self._ar_pg.append(grp)
            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
            for ar_pg in self._ar_pg:
                torch.distributed.all_reduce(self._overflow_buf, group=ar_pg)
        rs_ranks = []
        for group_i in range(self._num_groups):
            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
        self._rs_pg = []
        for group_i in range(self._num_groups):
            ranks = rs_ranks[group_i]
            for i in range(self._num_rs_pg):
                grp = torch.distributed.new_group(ranks=ranks)
                if torch.distributed.get_rank() in ranks:
                    self._rs_pg.append(grp)
            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
            if torch.distributed.get_rank() in ranks:
                self._l2_grad_norm_pg = l2_grad_norm_pg
                torch.distributed.all_reduce(self._overflow_buf, group=self._l2_grad_norm_pg)
        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
        for rs_pg in self._rs_pg:
            torch.distributed.all_reduce(self._overflow_buf, group=rs_pg)
        if self._num_ag_pg == 0:
            self._ag_pg = self._rs_pg
            self._ag_st = self._rs_st
            self._num_ag_pg = self._num_rs_pg
        else:
            self._ag_pg = []
            for group_i in range(self._num_groups):
                ranks = rs_ranks[group_i]
                for i in range(self._num_ag_pg):
                    grp = torch.distributed.new_group(ranks=ranks)
                    if torch.distributed.get_rank() in ranks:
                        self._ag_pg.append(grp)
            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
            for ag_pg in self._ag_pg:
                torch.distributed.all_reduce(self._overflow_buf, group=ag_pg)
        self._l2_grad_norm_st = torch.cuda.Stream()
        self._completion_st = torch.cuda.Stream()

        self._reductions_works = [None]*self._num_blocks
        self._allgather_works = [None]*self._num_blocks

        import inspect
        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
    def set_last_step(self, last_step):
        self._last_step = last_step

    def _get_flush_block(self):
        flush_block = []
        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
            num_grads = len(self._grads_generated)
            contiguous_idx = num_grads
            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
                contiguous_idx -= 1

            if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
                self._current_block -= 1
                start = self._current_block * self._block_size
                end = (self._current_block+1) * self._block_size
                flush_block = [start, end]

        return flush_block

    def _pipeline_block_reductions(self, block_id):
        self._flatten_grad_mt(1.0/self._world_size)

        # Reduction within each node
        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
        # The output format is the same as the fp32 master parameters
        works = [None]*self._num_chunks
        for chunk_id in range(self._num_chunks):
            glob_chunk_id = block_id * self._num_chunks + chunk_id
            rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
            rs_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(rs_stream):
                works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id], self._flat_grads_shards[block_id][chunk_id], group=self._rs_pg[glob_chunk_id%self._num_rs_pg], async_op=True, no_copy=True)

        # Reduction across nodes for each rank
        if self._num_groups > 1:
            for chunk_id in range(self._num_chunks):
                glob_chunk_id = block_id * self._num_chunks + chunk_id
                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
                with torch.cuda.stream(ar_stream):
                    works[chunk_id].wait()
                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], group=self._ar_pg[glob_chunk_id%self._num_ar_pg], async_op=True)
        self._reductions_works[block_id] = works

        # Compute L2 grad norm
        if block_id == 0:
            with torch.cuda.stream(self._l2_grad_norm_st):
                for block_id in range(self._num_blocks):
                    for chunk_id in range(self._num_chunks):
                        self._reductions_works[block_id][chunk_id].wait()
                # Since the packed format is contiguous after reductions, only one norm is needed
                l2_grad_norm_sq = torch.empty([1], device='cuda')
                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
                self._L2_grad_norm = l2_grad_norm_sq.sqrt().item()

    def __launch_step_kernel(self, p, p_copy, m, v, g):
        combined_scale = self._global_scale
        if self._param_group['max_grad_norm'] > 0 and math.isfinite(self.L2_grad_norm):
            combined_scale = self._param_group['max_grad_norm'] / (self.L2_grad_norm / self._global_scale + 1e-6)
            combined_scale = self._global_scale / min(1, combined_scale)
        bias_correction = 1 if self._param_group['bias_correction'] else 0
        beta1, beta2 = self._param_group['betas']
        fused_adam_cuda.reversible_adam(
                p, p_copy, m, v, g,
                self._param_group['lr'],
                beta1,
                beta2,
                self._param_group['eps'],
                combined_scale,
                self._param_state['step']+1,
                self.eps_mode,
                bias_correction,
                self._param_group['weight_decay'])
    def __compute_contrib_param_norm(self):
        if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
            # multi_tensor_l2norm returns (total_norm, per_tensor_norms); keep the per-tensor norms
            gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]
            gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]
            gnorm = torch.empty(size=[self._contrib_model_param_for_norm_num], dtype=torch.float32, device='cuda')
            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)
            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)
        elif self._contrib_model_param_for_norm_fp16 is not None:
            gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]
        elif self._contrib_model_param_for_norm_fp32 is not None:
            gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]
        return gnorm

    def __compute_contrib_update_norm(self):
        l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
        local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._dummy_overflow_buf, [self._contrib_update_frag_for_norm], True)[1]**2
        contrib_l2_norm = l2_norm[self._contrib_min_param_i:self._contrib_max_param_i+1]
        contrib_l2_norm.copy_(local_contrib_l2_norm)
        torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
        return l2_norm
    def _pipeline_step(self):
        # Call step kernel once per step
        # Call all-gather once per step
        with torch.cuda.stream(self._completion_st):
            for block_id in range(self._num_blocks):
                for chunk_id in range(self._num_chunks):
                    self._reductions_works[block_id][chunk_id].wait()
            param_norm = self.__compute_contrib_param_norm()
            max_grad_norm = self.defaults['max_grad_norm']
            multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,
                    self._dummy_overflow_buf,
                    self._contrib_compute_update_term_tensor_list,  # g, p, m, v, u
                    self._contrib_beta1,
                    self._contrib_beta2,
                    self._contrib_beta3,
                    self._contrib_bias_correction,
                    self._param_state['step']+1,
                    self._contrib_epsilon,
                    1,  # adam mode. FIXME: Correct value
                    self._contrib_weight_decay,
                    self.L2_grad_norm,
                    max_grad_norm)
            upd_norm = self.__compute_contrib_update_norm()
            multi_tensor_applier(self.multi_tensor_lamb_update_weights,
                    self._dummy_overflow_buf,
                    self._contrib_update_weights_tensor_list,  # u, p, p_copy
                    param_norm,
                    upd_norm,
                    self._param_group['lr'],
                    self._contrib_weight_decay,
                    self._use_nvlamb)
            torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)

    def _flatten_grad_mt(self, scale):
        if len(self._grads_fp16) > 0:
            self._overflow_buf.zero_()
            multi_tensor_applier(
                    amp_C.multi_tensor_scale,
                    self._overflow_buf,
                    list(zip(*self._grads_fp16)),
                    scale)
            self._grads_fp16 = []
        if len(self._grads_fp32) > 0:
            self._overflow_buf.zero_()
            multi_tensor_applier(
                    amp_C.multi_tensor_scale,
                    self._overflow_buf,
                    list(zip(*self._grads_fp32)),
                    scale)
            self._grads_fp32 = []
    def _do_overlapped_reduction(self, param_i, param_grads_size, param_offset, param):
        # handle overlapped reductions
        if param.dtype == torch.float16:
            self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) )
        else:
            self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
        self._grads_generated[param_i] = True
        if self._overlap_reductions and not self._last_step:
            flush_block = self._get_flush_block()
            while flush_block:
                block_id = flush_block[0] // self._block_size
                self._pipeline_block_reductions(block_id)
                flush_block = self._get_flush_block()

    def set_global_scale(self, global_scale):
        """Set global scale.
        """
        self._global_scale = global_scale

    @property
    def global_scale(self):
        return self._global_scale

    @property
    def has_overflow(self):
        """Check if overflows were detected by any call to step(...) method.
        Clears the overflow flag.
        """
        has_overflow = self._has_overflow
        self._has_overflow = False
        return has_overflow

    @property
    def peek_overflow(self):
        """Check if overflows were detected by any call to step(...) method.
        Does not clear overflow flag.
        """
        return self._has_overflow

    def strided_check_finite(self, output_params, stride=1, start=-1, end=-1, clear=True):
        """Strided check for overflow.
        You can get status by calling has_overflow.
        """
        if start >= 0 and start < end:
            out_p = output_params[start:end]
        else:
            out_p = output_params
        fused_adam_cuda.strided_check_finite(self._overflow_buf,
                out_p,
                stride,
                1 if clear else 0)
        self._has_overflow = False if self._overflow_buf.item() == 0 else True
        return self._has_overflow

    @property
    def L2_grad_norm(self):
        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
        return self._L2_grad_norm

    def complete_reductions(self):
        """Complete reductions if full pipeline is not selected or overlap is not allowed.
        """
        if self._last_step:
            # zero out gradients that have not been completed yet
            for param_i, grad_generated in enumerate(self._grads_generated):
                if not grad_generated:
                    grad_info = self._grads_info[param_i]
                    param_offset = grad_info["param_offset"]
                    param_size = grad_info["param_grads_size"]
                    self._flat_grads[param_offset:param_offset+param_size].zero_()
                    self._grads_generated[param_i] = True

        if self._last_step or not self._overlap_reductions:
            # nothing done so far, run full pipeline after reductions
            for block_id in range(self._num_blocks-1, -1, -1):
                self._pipeline_block_reductions(block_id)

        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

        self._current_block = self._num_blocks
        self._grads_generated = [False]*len(self._grads_info)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        self._pipeline_step()

        with torch.cuda.stream(self._completion_st):
            # Copy self._new_params to model params
            for p in self._model_params: self.state[p]['step'] += 1
            if self._packed_flat_to_model_params_fp16 is not None:
                multi_tensor_applier(
                        fused_adam_cuda.maybe_cast_mt,
                        self._overflow_buf,
                        self._packed_flat_to_model_params_fp16)
            if self._packed_flat_to_model_params_fp32 is not None:
                multi_tensor_applier(
                        fused_adam_cuda.maybe_cast_mt,
                        self._overflow_buf,
                        self._packed_flat_to_model_params_fp32)

        torch.cuda.current_stream().wait_stream(self._completion_st)

        self._reductions_works = [None]*self._num_blocks
        self._allgather_works = [None]*self._num_blocks

        return loss
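A rough sketch of how this optimizer is meant to be driven, pieced together from the methods above (the gradient hooks trigger overlapped reductions during backward, then complete_reductions and step finish the pipeline). The model, data, world-size setup, and hyperparameter values are illustrative assumptions, and a c10d build that supports the no_copy option (as asserted in __init__) is required.

# Hypothetical training-loop sketch for DistributedFusedLAMB; placeholders throughout.
import torch
import torch.distributed as dist
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB

dist.init_process_group(backend='nccl')
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = torch.nn.Linear(1024, 1024).half().cuda()
optimizer = DistributedFusedLAMB(model.parameters(), lr=1e-3, weight_decay=0.01,
                                 max_grad_norm=1.0, overlap_reductions=True)
optimizer.set_global_scale(1.0)   # loss-scale bookkeeping used by the optimizer's helpers

for it in range(10):
    x = torch.randn(32, 1024, device='cuda', dtype=torch.float16)
    loss = model(x).float().pow(2).mean()
    loss.backward()                   # gradient hooks start overlapped reductions
    optimizer.complete_reductions()   # finish reductions not overlapped with backward
    optimizer.step()
    optimizer.zero_grad()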
setup.py (modified)
...
@@ -220,6 +220,8 @@ if "--deprecated_fused_lamb" in sys.argv:
     CUDAExtension(name='fused_lamb_cuda',
                   sources=['apex/contrib/csrc/optimizers/fused_lamb_cuda.cpp',
                            'apex/contrib/csrc/optimizers/fused_lamb_cuda_kernel.cu',
+                           'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb.cpp',
+                           'apex/contrib/csrc/optimizers/multi_tensor_distopt_lamb_kernel.cu',
                            'csrc/multi_tensor_l2norm_kernel.cu'],
                   include_dirs=[os.path.join(this_dir, 'csrc')],
                   extra_compile_args={'cxx': ['-O3',] + version_dependent_macros,
...