OpenDAS / apex / Commits / 6f7a8b39

Commit 6f7a8b39, authored Jul 27, 2020 by lcskrishna

Merge remote-tracking branch 'rocm_upstream/master' into ifu_07272020

Parents: 459de22d, 9c80f6d3
Changes: 63
Showing 20 changed files with 231 additions and 137 deletions (+231 / -137).
apex/optimizers/fused_adagrad.py      +2  -2
apex/optimizers/fused_adam.py         +2  -2
apex/optimizers/fused_lamb.py         +2  -2
apex/optimizers/fused_novograd.py     +2  -2
apex/parallel/distributed.py          +7  -6
apex/testing/__init__.py              +0  -0
apex/testing/common_utils.py          +22 -0
csrc/layer_norm_cuda.cpp              +4  -2
csrc/layer_norm_cuda_kernel.cu        +14 -7
csrc/multi_tensor_adagrad.cu          +8  -8
csrc/multi_tensor_adam.cu             +10 -10
csrc/multi_tensor_apply.cuh           +17 -5
csrc/multi_tensor_axpby_kernel.cu     +10 -10
csrc/multi_tensor_l2norm_kernel.cu    +25 -17
csrc/multi_tensor_lamb.cu             +17 -17
csrc/multi_tensor_lamb_stage_1.cu     +13 -13
csrc/multi_tensor_lamb_stage_2.cu     +9  -9
csrc/multi_tensor_novograd.cu         +9  -9
csrc/multi_tensor_scale_kernel.cu     +8  -8
csrc/multi_tensor_sgd_kernel.cu       +50 -8
apex/optimizers/fused_adagrad.py (+2 / -2)

@@ -91,7 +91,7 @@ class FusedAdagrad(torch.optim.Optimizer):
                 if len(state) == 0:
                     # Exponential moving average of gradient values
                     state['sum'] = torch.zeros_like(p.data)
-                if p.dtype == torch.float16:
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     g_16.append(p.grad.data)
                     p_16.append(p.data)
                     h_16.append(state['sum'])
@@ -100,7 +100,7 @@ class FusedAdagrad(torch.optim.Optimizer):
                     p_32.append(p.data)
                     h_32.append(state['sum'])
                 else:
-                    raise RuntimeError('FusedAdagrad only support fp16 and fp32.')
+                    raise RuntimeError('FusedAdagrad only support fp16, bfloat16 and fp32.')
             if (len(g_16) > 0):
                 multi_tensor_applier(self.multi_tensor_adagrad,
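The same two-line pattern recurs in FusedAdam, FusedLAMB and FusedNovoGrad below: parameters whose dtype is torch.float16 or torch.bfloat16 are routed into the 16-bit tensor lists, fp32 parameters get their own lists, and anything else raises. A minimal standalone sketch of that bucketing logic (the function name and list layout are illustrative, not the optimizers' actual attributes):

import torch

def bucket_params_by_dtype(params):
    # fp16 and bf16 tensors share the "16-bit" lists handed to the fused kernel;
    # fp32 gets its own lists; any other dtype is rejected.
    g_16, p_16, g_32, p_32 = [], [], [], []
    for p in params:
        if p.grad is None:
            continue
        if p.dtype in {torch.float16, torch.bfloat16}:
            g_16.append(p.grad.data)
            p_16.append(p.data)
        elif p.dtype == torch.float32:
            g_32.append(p.grad.data)
            p_32.append(p.data)
        else:
            raise RuntimeError('only fp16, bfloat16 and fp32 are supported.')
    return (g_16, p_16), (g_32, p_32)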
apex/optimizers/fused_adam.py (+2 / -2)

@@ -130,7 +130,7 @@ class FusedAdam(torch.optim.Optimizer):
                     # Exponential moving average of squared gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)
-                if p.dtype == torch.float16:
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     g_16.append(p.grad.data)
                     p_16.append(p.data)
                     m_16.append(state['exp_avg'])
@@ -141,7 +141,7 @@ class FusedAdam(torch.optim.Optimizer):
                     m_32.append(state['exp_avg'])
                     v_32.append(state['exp_avg_sq'])
                 else:
-                    raise RuntimeError('FusedAdam only support fp16 and fp32.')
+                    raise RuntimeError('FusedAdam only support fp16, bfloat16 and fp32.')
             if (len(g_16) > 0):
                 multi_tensor_applier(self.multi_tensor_adam,
apex/optimizers/fused_lamb.py (+2 / -2)

@@ -165,7 +165,7 @@ class FusedLAMB(torch.optim.Optimizer):
                     # Exponential moving average of gradient values
                     state['exp_avg_sq'] = torch.zeros_like(p.data)
-                if p.dtype == torch.float16:
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     g_16.append(p.grad.data)
                     p_16.append(p.data)
                     m_16.append(state['exp_avg'])
@@ -176,7 +176,7 @@ class FusedLAMB(torch.optim.Optimizer):
                     m_32.append(state['exp_avg'])
                     v_32.append(state['exp_avg_sq'])
                 else:
-                    raise RuntimeError('FusedLAMB only support fp16 and fp32.')
+                    raise RuntimeError('FusedLAMB only support fp16, bfloat16 and fp32.')
             if (len(g_16) > 0):
                 multi_tensor_applier(self.multi_tensor_lamb,
apex/optimizers/fused_novograd.py (+2 / -2)

@@ -142,7 +142,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
                     # Exponential moving average of gradient values
                     state['exp_avg'] = torch.zeros_like(p.data)
-                if p.dtype == torch.float16:
+                if p.dtype in {torch.float16, torch.bfloat16}:
                     g_16.append(p.grad.data)
                     p_16.append(p.data)
                     m_16.append(state['exp_avg'])
@@ -151,7 +151,7 @@ class FusedNovoGrad(torch.optim.Optimizer):
                     p_32.append(p.data)
                     m_32.append(state['exp_avg'])
                 else:
-                    raise RuntimeError('FusedNovoGrad only support fp16 and fp32.')
+                    raise RuntimeError('FusedNovoGrad only support fp16, bfloat16 and fp32.')
             # we store per weight norm as one tensor for one group/precision combination
             # different from optim.Adam, we store norm here(not ^2) so we can unify calculation for norm types
apex/parallel/distributed.py (+7 / -6)

@@ -48,8 +48,8 @@ def apply_flat_dist_call(bucket, call, extra_args=None):
     for buf, synced in zip(bucket, unflatten(coalesced, bucket)):
         buf.copy_(synced)

-def split_half_float_double(tensors):
-    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor"]
+def split_half_float_double_bfloat16(tensors):
+    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor", "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
     buckets = []
     for i, dtype in enumerate(dtypes):
         bucket = [t for t in tensors if t.type() == dtype]
@@ -240,7 +240,8 @@ class DistributedDataParallel(Module):
         self.param_type_to_tmp_i = {"torch.cuda.HalfTensor" : 0,
                                     "torch.cuda.FloatTensor" : 1,
-                                    "torch.cuda.DoubleTensor" : 2}
+                                    "torch.cuda.DoubleTensor" : 2,
+                                    "torch.cuda.BFloat16Tensor" : 3}
         if multi_tensor_applier.available:
             # TODO:  I really need to centralize the C++ backed imports
@@ -498,7 +499,7 @@ class DistributedDataParallel(Module):
         else:
             grads = [param.grad.data for param in self.module.parameters() if param.grad is not None]

-        split_buckets = split_half_float_double(grads)
+        split_buckets = split_half_float_double_bfloat16(grads)

         # If retain_allreduce_buffers is True and delay_allreduce is False,
         # this will only be done during the first backward pass, ignored by the
@@ -578,8 +579,8 @@ class DistributedDataParallel(Module):
         if self.needs_refresh:
             self.active_i_buckets = []
             self.buckets = []
-            self.tmp_buckets = [[], [], []] # [running half, float, double buckets]
-            self.tmp_numels = [0, 0, 0]
+            self.tmp_buckets = [[], [], [], []] # [running half, float, double, bfloat16 buckets]
+            self.tmp_numels = [0, 0, 0, 0]
             self.bucket_sizes = []
             self.param_id_to_active_i = {id(param) : i for i, param in enumerate(param_list)}
             self.param_id_to_bucket = {}
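The renamed helper simply adds one more dtype string to the list it buckets by. A standalone sketch of the same idea (the trailing "append non-empty bucket and return" lines are assumed here, since the hunk cuts off after the list comprehension):

def split_by_cuda_dtype(tensors):
    # Group tensors by their CUDA tensor type, now including BFloat16,
    # mirroring split_half_float_double_bfloat16 above.
    dtypes = ["torch.cuda.HalfTensor", "torch.cuda.FloatTensor",
              "torch.cuda.DoubleTensor", "torch.cuda.BFloat16Tensor"]
    buckets = []
    for dtype in dtypes:
        bucket = [t for t in tensors if t.type() == dtype]
        if bucket:
            buckets.append(bucket)
    return buckets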
apex/testing/__init__.py 0 → 100644 (new, empty file)
apex/testing/common_utils.py 0 → 100644 (new file, +22)

'''
This file contains common utility functions for running the unit tests on ROCM.
'''
import torch
import os
import sys
from functools import wraps
import unittest

TEST_WITH_ROCM = os.getenv('APEX_TEST_WITH_ROCM', '0') == '1'

## Wrapper to skip the unit tests.
def skipIfRocm(fn):
    @wraps(fn)
    def wrapper(*args, **kwargs):
        if TEST_WITH_ROCM:
            raise unittest.SkipTest("test doesn't currently work on ROCm stack.")
        else:
            fn(*args, **kwargs)
    return wrapper
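For context, a test written against this helper would look roughly like the following; the test class and assertion are made up for illustration, only skipIfRocm and the APEX_TEST_WITH_ROCM environment variable come from the file above:

import unittest
import torch
from apex.testing.common_utils import skipIfRocm

class TestFusedOptimizers(unittest.TestCase):  # hypothetical test case
    @skipIfRocm  # raises unittest.SkipTest when APEX_TEST_WITH_ROCM=1 is set
    def test_something_cuda_specific(self):
        self.assertTrue(torch.cuda.is_available())

if __name__ == "__main__":
    unittest.main()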
csrc/layer_norm_cuda.cpp (+4 / -2)

@@ -130,7 +130,8 @@ std::vector<at::Tensor> layer_norm(
   int n1,n2;
   check_args(input,normalized_shape,n1,n2);
   at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+  at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
       normalized_shape,NULL,NULL,epsilon);
@@ -152,7 +153,8 @@ std::vector<at::Tensor> layer_norm_affine(
   int n1,n2;
   check_args(input,normalized_shape,gamma,beta,n1,n2);
   at::Tensor output = at::empty_like(input);
-  at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input.scalar_type()));
+  at::Tensor mean = at::empty({n1}, input.options().dtype((input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type()));
   at::Tensor invvar = at::empty_like(mean);
   cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
       normalized_shape,&gamma,&beta,epsilon);
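Both hunks make the same decision: the mean/invvar buffers stay in fp32 whenever the input is a reduced-precision type (Half or, now, BFloat16), and otherwise follow the input dtype. The rule, expressed as a small torch-level Python sketch rather than the C++ itself:

import torch

def layer_norm_stat_dtype(input_dtype):
    # fp16/bf16 inputs accumulate their statistics in fp32; fp32/fp64 keep their own dtype.
    if input_dtype in (torch.float16, torch.bfloat16):
        return torch.float32
    return input_dtype

assert layer_norm_stat_dtype(torch.bfloat16) is torch.float32
assert layer_norm_stat_dtype(torch.float64) is torch.float64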
csrc/layer_norm_cuda_kernel.cu (+14 / -7)

@@ -172,8 +172,8 @@ void cuWelfordMuSigma2(
       for(;  l+7 < n2;  l+=8*numx) {
         for(int k = 0;  k < 8;  k+=2) {
           float2 curr = __half22float2(*((__half2*)(lvals+l+k)));
-          cuWelfordOnlineSum(curr.x,mu,sigma2,count);
-          cuWelfordOnlineSum(curr.y,mu,sigma2,count);
+          cuWelfordOnlineSum<float>(curr.x,mu,sigma2,count);
+          cuWelfordOnlineSum<float>(curr.y,mu,sigma2,count);
         }
       }
       for(;  l < n2;  ++l) {
@@ -230,9 +230,15 @@ void cuWelfordMuSigma2(
 template<typename U> U rsqrt(U v) {
   return U(1) / sqrt(v);
 }
+#if defined __HIP_PLATFORM_HCC__
+__device__ float rsqrt(float v) {
+  return rsqrtf(v);
+}
+#else
 template<> float rsqrt(float v) {
   return rsqrtf(v);
 }
+#endif
 template<> double rsqrt(double v) {
   return rsqrt(v);
 }
@@ -293,7 +299,7 @@ void cuApplyLayerNorm(
   // 1) blockDim.x == warpSize
   // 2) Tensors are contiguous
   //
-  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
     SharedMemory<U> shared;
     U* buf = shared.getPointer();
     U mu,sigma2;
@@ -531,7 +537,7 @@ void cuComputeGradInput(
     const T* gamma,
     T* grad_input)
 {
-  for (auto i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (int i1=blockIdx.y; i1 < n1; i1 += gridDim.y) {
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
     const U c_mean = mean[i1];
@@ -684,7 +690,7 @@ void cuda_layer_norm(
     double epsilon)
 {
     using namespace at;
-    DISPATCH_DOUBLE_FLOAT_AND_HALF(input->scalar_type(), 0, "layer_norm_cuda_kernel",
+    DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(input->scalar_type(), 0, "layer_norm_cuda_kernel",
         using accscalar_t = at::acc_type<scalar_t_0, true>;
         HostApplyLayerNorm(
             output->DATA_PTR<scalar_t_0>(),
@@ -724,7 +730,8 @@ void HostLayerNormGradient(
     const int nshared2_a = 2 * sizeof(U) * threads2.y * threads2.y * (threads2.x + 1);
     const int nshared2_b = threads2.x * threads2.y * sizeof(U);
     const int nshared2 = nshared2_a > nshared2_b ? nshared2_a : nshared2_b;
-    at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype(input->scalar_type()==at::ScalarType::Half ? at::ScalarType::Float : input->scalar_type()));
+    at::Tensor part_grad_gamma = at::empty({part_size,n2}, input->options().dtype((input->scalar_type()==at::ScalarType::Half || input->scalar_type()==at::ScalarType::BFloat16) ? at::ScalarType::Float : input->scalar_type()));
     at::Tensor part_grad_beta = at::empty_like(part_grad_gamma);
     cuComputePartGradGammaBeta<<<blocks2, threads2, nshared2, stream>>>(
                     dout,
@@ -787,7 +794,7 @@ void cuda_layer_norm_gradient(
     at::Tensor* grad_beta)
 {
     using namespace at;
-    DISPATCH_FLOAT_AND_HALF(input->scalar_type(), 0, "cuComputeGradInput",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(input->scalar_type(), 0, "cuComputeGradInput",
         using accscalar_t = at::acc_type<scalar_t_0, true>;
         HostLayerNormGradient(
             dout->DATA_PTR<scalar_t_0>(),
csrc/multi_tensor_adagrad.cu (+8 / -8)

@@ -23,20 +23,20 @@ using MATH_T = float;
 template <typename T> struct AdagradFunctor {
   __device__ __forceinline__ void
-  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
+  operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> *tl,
              const float epsilon, const float lr, adagradMode_t mode,
              const float weight_decay) {
-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];

-    T *g = (T *)tl.addresses[0][tensor_loc];
+    T *g = (T *)tl->addresses[0][tensor_loc];
     g += chunk_idx * chunk_size;

-    T *p = (T *)tl.addresses[1][tensor_loc];
+    T *p = (T *)tl->addresses[1][tensor_loc];
     p += chunk_idx * chunk_size;

-    T *h = (T *)tl.addresses[2][tensor_loc];
+    T *h = (T *)tl->addresses[2][tensor_loc];
     h += chunk_idx * chunk_size;

     n -= chunk_idx * chunk_size;
@@ -90,7 +90,7 @@ void multi_tensor_adagrad_cuda(
   using namespace at;

   // Assume single type across p,g,h now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "adagrad",
       multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                             AdagradFunctor<scalar_t_0>(), epsilon, lr,
csrc/multi_tensor_adam.cu (+10 / -10)

@@ -26,7 +26,7 @@ struct AdamFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<4>& tl,
+    TensorListMetadata<4>* tl,
     const float beta1,
     const float beta2,
     const float beta1_correction,
@@ -40,24 +40,24 @@ struct AdamFunctor
     // if(*noop_gmem == 1)
     //   return;

-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];

     // potentially use to pass in list of scalar
-    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;
+    // int tensor_num = tl->start_tensor_this_launch + tensor_loc;

-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];

-    T* g = (T*)tl.addresses[0][tensor_loc];
+    T* g = (T*)tl->addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;

-    T* p = (T*)tl.addresses[1][tensor_loc];
+    T* p = (T*)tl->addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;

-    T* m = (T*)tl.addresses[2][tensor_loc];
+    T* m = (T*)tl->addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;

-    T* v = (T*)tl.addresses[3][tensor_loc];
+    T* v = (T*)tl->addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;
@@ -149,7 +149,7 @@ void multi_tensor_adam_cuda(
   }

   // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "adam",
     multi_tensor_apply<4>(
       BLOCK_SIZE,
csrc/multi_tensor_apply.cuh (+17 / -5)

@@ -2,6 +2,7 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
+#include <THC/THC.h>
 #include "compat.h"

 #include <assert.h>
@@ -29,7 +30,7 @@ template<typename T, typename U, typename... ArgTypes>
 __global__ void multi_tensor_apply_kernel(
     int chunk_size,
     volatile int* noop_flag,
-    T tl,
+    T* tl,
     U callable,
     ArgTypes... args)
 {
@@ -56,7 +57,7 @@ void multi_tensor_apply(
     for(int t = 0; t < tensor_lists[l].size(); t++)
     {
       // TODO:  Print which tensor fails.
-      bool contiguous_memory = tensor_lists[l][t].is_contiguous();
+      bool contiguous_memory = (tensor_lists[l][t].is_sparse()) ? tensor_lists[l][t]._values().is_contiguous() : tensor_lists[l][t].is_contiguous();
 #ifdef VERSION_GE_1_5
       contiguous_memory = (contiguous_memory || tensor_lists[l][t].is_contiguous(at::MemoryFormat::ChannelsLast));
 #endif
@@ -78,8 +79,15 @@ void multi_tensor_apply(
   for(int t = 0; t < ntensors; t++)
   {
     tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
     for(int d = 0; d < depth; d++)
-      tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+    {
+      if (tensor_lists[d][t].is_sparse()) {
+        at::Tensor dst = at::zeros(tensor_lists[d][t].sizes(), tensor_lists[d][t].options().layout(at::kStrided));
+        dst.add_(tensor_lists[d][t]);
+        tl.addresses[d][loc_tensor_info] = dst.data_ptr();
+      } else {
+        tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
+      }
+    }
     loc_tensor_info++;

     int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
@@ -97,11 +105,15 @@ void multi_tensor_apply(
       bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
       if(tensors_full || blocks_full || last_chunk)
       {
+        auto storage = at::empty(sizeof(tl), c10::TensorOptions(at::kStrided).dtype(at::kByte).device(at::kCPU).pinned_memory(true));
+        auto tl_as_host_pinned_ptr = static_cast<decltype(tl)*>(storage.data_ptr());
+        memcpy(tl_as_host_pinned_ptr, &tl, sizeof(tl));
+        AT_CUDA_CHECK(THCCachingHostAllocator_recordEvent(tl_as_host_pinned_ptr, stream));
         // using accscalar_t = acc_type<scalar_t, true>;
         multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
           chunk_size,
           noop_flag.DATA_PTR<int>(),
-          tl,
+          tl_as_host_pinned_ptr,
           callable,
           args...);
csrc/multi_tensor_axpby_kernel.cu (+10 / -10)

@@ -30,7 +30,7 @@ struct AxpbyFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<3>& tl,
+    TensorListMetadata<3>* tl,
     float a,
     float b,
     int arg_to_check)
@@ -39,17 +39,17 @@ struct AxpbyFunctor
     // if(*noop_gmem == 1)
     //   return;

-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];

-    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
+    x_t* x = (x_t*)tl->addresses[0][tensor_loc];
     x += chunk_idx*chunk_size;

-    y_t* y = (y_t*)tl.addresses[1][tensor_loc];
+    y_t* y = (y_t*)tl->addresses[1][tensor_loc];
     y += chunk_idx*chunk_size;

-    out_t* out = (out_t*)tl.addresses[2][tensor_loc];
+    out_t* out = (out_t*)tl->addresses[2][tensor_loc];
     out += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;
@@ -138,9 +138,9 @@ void multi_tensor_axpby_cuda(
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
-      DISPATCH_FLOAT_AND_HALF(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_axpby_cuda",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_axpby_cuda",
+      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[2][0].scalar_type(), 2, "multi_tensor_axpby_cuda",
         multi_tensor_apply<3>(
           BLOCK_SIZE,
           chunk_size,
csrc/multi_tensor_l2norm_kernel.cu (+25 / -17)

@@ -30,7 +30,7 @@ struct L2NormFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<1>& tl,
+    TensorListMetadata<1>* tl,
     float* output,
     float* output_per_tensor,
     bool per_tensor,
@@ -40,11 +40,11 @@ struct L2NormFunctor
     // if(*noop_gmem == 1)
     //   return;

-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];

-    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
+    x_t* x = (x_t*)tl->addresses[0][tensor_loc];
     x += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;
@@ -103,7 +103,7 @@ struct L2NormFunctor
       *noop_gmem = 1; // Blindly fire off a write.  These will race but that's ok.
       output[blockIdx.x] += final;
       if(per_tensor)
-        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
+        output_per_tensor[(tl->start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
     }
   }
 };
@@ -115,7 +115,7 @@ struct MaxNormFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<1>& tl,
+    TensorListMetadata<1>* tl,
     float* output,
     float* output_per_tensor,
     bool per_tensor,
@@ -125,11 +125,11 @@ struct MaxNormFunctor
     // if(*noop_gmem == 1)
     //   return;

-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];

-    x_t* x = (x_t*)tl.addresses[0][tensor_loc];
+    x_t* x = (x_t*)tl->addresses[0][tensor_loc];
     x += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;
@@ -188,13 +188,17 @@ struct MaxNormFunctor
       *noop_gmem = 1; // Blindly fire off a write.  These will race but that's ok.
       output[blockIdx.x] = fmaxf(fabsf(output[blockIdx.x]), fabsf(final));
       if(per_tensor)
-        output_per_tensor[(tl.start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
+        output_per_tensor[(tl->start_tensor_this_launch + tensor_loc)*max_chunks_per_tensor + chunk_idx] = final;
     }
   }
 };

-__global__ void cleanup(
+__global__ void
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(1024)
+#endif
+cleanup(
   float* output,
   float* output_per_tensor,
   float* ret,
@@ -231,7 +235,11 @@ __global__ void cleanup(
   }
 }

-__global__ void cleanup_v2(
+__global__ void
+#ifdef __HIP_PLATFORM_HCC__
+__launch_bounds__(1024)
+#endif
+cleanup_v2(
   float* output,
   float* output_per_tensor,
   float* ret,
@@ -322,7 +330,7 @@ std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
     ret_per_tensor = at::empty({0}, float_options);
   }

-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
     multi_tensor_apply<1>(
       BLOCK_SIZE,
       chunk_size,
@@ -391,7 +399,7 @@ void multi_tensor_norm_out_cuda(
   output_per_tensor = at::zeros({ntensors*max_chunks_per_tensor}, float_options);

   if (norm_type == 0) {
-    DISPATCH_FLOAT_AND_HALF(
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_maxnorm_cuda",
       multi_tensor_apply<1>(
         BLOCK_SIZE,
@@ -405,7 +413,7 @@ void multi_tensor_norm_out_cuda(
           max_chunks_per_tensor);)
   }
   else {
-    DISPATCH_FLOAT_AND_HALF(
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(
       tensor_lists[0][0].scalar_type(), 0, "multi_tensor_l2norm_cuda",
       multi_tensor_apply<1>(
         BLOCK_SIZE,
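From Python these kernels are reached through apex's multi_tensor_applier, and with the dispatch widened above, bf16 gradient lists become legal inputs. A hedged usage sketch (it assumes apex was built with the CUDA/HIP extensions so that the amp_C module and its multi_tensor_l2norm binding are importable; the tensor shapes and the trailing False per-tensor flag follow how the fused optimizers call it):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

grads = [torch.randn(1024, device="cuda", dtype=torch.bfloat16) for _ in range(4)]
overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")

# First element of the returned pair is the global L2 norm across all tensors;
# passing False skips materializing the per-tensor norms.
global_norm, _ = multi_tensor_applier(
    amp_C.multi_tensor_l2norm, overflow_buf, [grads], False)
print(global_norm.item())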
csrc/multi_tensor_lamb.cu (+17 / -17)

@@ -43,7 +43,7 @@ struct LAMBStage1Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<4>& tl,
+    TensorListMetadata<4>* tl,
     const float beta1,
     const float beta2,
     const float beta3,
@@ -59,22 +59,22 @@ struct LAMBStage1Functor
     // if(*noop_gmem == 1)
     //   return;

-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];

     float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;

-    T* g = (T*)tl.addresses[0][tensor_loc];
+    T* g = (T*)tl->addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;

-    T* p = (T*)tl.addresses[1][tensor_loc];
+    T* p = (T*)tl->addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;

-    T* m = (T*)tl.addresses[2][tensor_loc];
+    T* m = (T*)tl->addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;

-    T* v = (T*)tl.addresses[3][tensor_loc];
+    T* v = (T*)tl->addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;
@@ -236,7 +236,7 @@ struct LAMBStage2Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<2>& tl,
+    TensorListMetadata<2>* tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
     const float learning_rate,
@@ -247,10 +247,10 @@ struct LAMBStage2Functor
     // if(*noop_gmem == 1)
     //   return;

-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];

     MATH_T ratio = learning_rate;
     // nvlamb: apply adaptive learning rate to all parameters
@@ -262,10 +262,10 @@ struct LAMBStage2Functor
       ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
     }

-    T* update = (T*)tl.addresses[0][tensor_loc];
+    T* update = (T*)tl->addresses[0][tensor_loc];
     update += chunk_idx*chunk_size;

-    T* p = (T*)tl.addresses[1][tensor_loc];
+    T* p = (T*)tl->addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;
@@ -372,7 +372,7 @@ void multi_tensor_lamb_cuda(
   // We now in-place modify grad to store update before compute its norm
   // Generally this is not a issue since people modify grad in step() method all the time
   // We can also grab list of empty tensor to avoid this, but I'd like to save space/cpu code
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
     multi_tensor_apply<4>(
       BLOCK_SIZE,
       chunk_size,
@@ -395,7 +395,7 @@ void multi_tensor_lamb_cuda(
   std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);

-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
     multi_tensor_apply<2>(
       BLOCK_SIZE,
       chunk_size,
csrc/multi_tensor_lamb_stage_1.cu (+13 / -13)

@@ -20,7 +20,7 @@ struct LAMBStage1Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<5>& tl,
+    TensorListMetadata<5>* tl,
     const float* per_tensor_decay,
     const float beta1,
     const float beta2,
@@ -33,26 +33,26 @@ struct LAMBStage1Functor
     // if(*noop_gmem == 1)
     //   return;

-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];

     float decay = per_tensor_decay[tensor_num];

-    GRAD_T* g = (GRAD_T*)tl.addresses[0][tensor_loc];
+    GRAD_T* g = (GRAD_T*)tl->addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;

-    T* p = (T*)tl.addresses[1][tensor_loc];
+    T* p = (T*)tl->addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;

-    T* m = (T*)tl.addresses[2][tensor_loc];
+    T* m = (T*)tl->addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;

-    T* v = (T*)tl.addresses[3][tensor_loc];
+    T* v = (T*)tl->addresses[3][tensor_loc];
     v += chunk_idx*chunk_size;

-    UPD_T* update = (UPD_T*)tl.addresses[4][tensor_loc];
+    UPD_T* update = (UPD_T*)tl->addresses[4][tensor_loc];
     update += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;
@@ -128,9 +128,9 @@ void multi_tensor_lamb_stage1_cuda(
   float next_step = float(step+1);
   float beta1_correction = 1.0f - std::pow(beta1, next_step);
   float beta2_correction = 1.0f - std::pow(beta2, next_step);
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
-      DISPATCH_FLOAT_AND_HALF(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_1",
+      DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[4][0].scalar_type(), 2, "lamb_stage_1",
        multi_tensor_apply<5>(
          BLOCK_SIZE,
          chunk_size,
csrc/multi_tensor_lamb_stage_2.cu (+9 / -9)

@@ -23,7 +23,7 @@ struct LAMBStage2Functor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<2>& tl,
+    TensorListMetadata<2>* tl,
     const float* per_tensor_param_norm,
     const float* per_tensor_update_norm,
     const float learning_rate,
@@ -34,10 +34,10 @@ struct LAMBStage2Functor
     // if(*noop_gmem == 1)
     //   return;

-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];

     MATH_T ratio = learning_rate;
     // nvlamb: apply adaptive learning rate to all parameters
@@ -49,10 +49,10 @@ struct LAMBStage2Functor
       ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
     }

-    T* p = (T*)tl.addresses[0][tensor_loc];
+    T* p = (T*)tl->addresses[0][tensor_loc];
     p += chunk_idx*chunk_size;

-    UPD_T* update = (UPD_T*)tl.addresses[1][tensor_loc];
+    UPD_T* update = (UPD_T*)tl->addresses[1][tensor_loc];
     update += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;
@@ -105,8 +105,8 @@ void multi_tensor_lamb_stage2_cuda(
   using namespace at;

-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "lamb_stage_2",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
csrc/multi_tensor_novograd.cu (+9 / -9)

@@ -35,7 +35,7 @@ struct NovoGradFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<3>& tl,
+    TensorListMetadata<3>* tl,
     const float beta1,
     const float beta2,
     const float beta3,
@@ -51,20 +51,20 @@ struct NovoGradFunctor
     // if(*noop_gmem == 1)
     //   return;

-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int tensor_num = tl->start_tensor_this_launch + tensor_loc;
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];

     float grad_norm = per_tensor_grad_norm[tensor_num];

-    T* g = (T*)tl.addresses[0][tensor_loc];
+    T* g = (T*)tl->addresses[0][tensor_loc];
     g += chunk_idx*chunk_size;

-    T* p = (T*)tl.addresses[1][tensor_loc];
+    T* p = (T*)tl->addresses[1][tensor_loc];
     p += chunk_idx*chunk_size;

-    T* m = (T*)tl.addresses[2][tensor_loc];
+    T* m = (T*)tl->addresses[2][tensor_loc];
     m += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;
@@ -164,7 +164,7 @@ void multi_tensor_novograd_cuda(
   multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms, beta2, (1.0f - beta2), norm_type);

   // Assume single type across p,g,m1,m2 now
-  DISPATCH_DOUBLE_FLOAT_AND_HALF(
+  DISPATCH_DOUBLE_FLOAT_AND_HALF_AND_BFLOAT16(
     tensor_lists[0][0].scalar_type(), 0, "novograd",
     multi_tensor_apply<3>(
       BLOCK_SIZE,
csrc/multi_tensor_scale_kernel.cu (+8 / -8)

@@ -32,21 +32,21 @@ struct ScaleFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<2>& tl,
+    TensorListMetadata<2>* tl,
     float scale)
   {
     // I'd like this kernel to propagate infs/nans.
     // if(*noop_gmem == 1)
     //   return;

-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];

-    in_t* in = (in_t*)tl.addresses[0][tensor_loc];
+    in_t* in = (in_t*)tl->addresses[0][tensor_loc];
     in += chunk_idx*chunk_size;

-    out_t* out = (out_t*)tl.addresses[1][tensor_loc];
+    out_t* out = (out_t*)tl->addresses[1][tensor_loc];
     out += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;
@@ -121,8 +121,8 @@ void multi_tensor_scale_cuda(
   // If build times suffer, think about where to put this dispatch,
   // and what logic should be moved out of multi_tensor_apply.
-  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
-    DISPATCH_FLOAT_AND_HALF(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
+  DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
+    DISPATCH_FLOAT_AND_HALF_AND_BFLOAT16(tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
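With the scale kernel dispatching on bf16 as well, a list of bf16 tensors can be unscaled and copied with a single fused launch. A usage sketch mirroring the existing fp16 path (again assuming the amp_C extension is built; tensor shapes and the loss scale value are arbitrary):

import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

src = [torch.randn(512, device="cuda", dtype=torch.bfloat16) for _ in range(8)]
dst = [torch.empty_like(t, dtype=torch.float32) for t in src]
overflow_buf = torch.zeros(1, dtype=torch.int, device="cuda")

# Copies src into dst while multiplying by 1/loss_scale; overflow_buf becomes
# nonzero if an inf/nan is seen, so the caller can skip the optimizer step.
loss_scale = 128.0
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [src, dst], 1.0 / loss_scale)
found_inf = overflow_buf.item() != 0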
csrc/multi_tensor_sgd_kernel.cu (+50 / -8)

@@ -32,7 +32,7 @@ struct SGDFunctor
   __device__ __forceinline__ void operator()(
     int chunk_size,
     volatile int* noop_gmem,
-    TensorListMetadata<N>& tl,
+    TensorListMetadata<N>* tl,
     float wd,
     float momentum,
     float dampening,
@@ -45,23 +45,23 @@ struct SGDFunctor
     // Early exit if we don't need to do anything
     if (*noop_gmem) return;

-    int tensor_loc = tl.block_to_tensor[blockIdx.x];
-    int chunk_idx = tl.block_to_chunk[blockIdx.x];
-    int n = tl.sizes[tensor_loc];
+    int tensor_loc = tl->block_to_tensor[blockIdx.x];
+    int chunk_idx = tl->block_to_chunk[blockIdx.x];
+    int n = tl->sizes[tensor_loc];

-    T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
+    T_grad* grad_in = (T_grad*)tl->addresses[0][tensor_loc];
     grad_in += chunk_idx*chunk_size;

-    T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
+    T_weight* weight_in = (T_weight*)tl->addresses[1][tensor_loc];
     weight_in += chunk_idx*chunk_size;

-    T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
+    T_weight* mom_in = (T_weight*)tl->addresses[2][tensor_loc];
     mom_in += chunk_idx*chunk_size;

     at::Half *model_weights_out = nullptr;
     if(N == 4)
     {
-      model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
+      model_weights_out = (at::Half*)tl->addresses[3][tensor_loc];
       model_weights_out += chunk_idx*chunk_size;
     }
@@ -166,6 +166,8 @@ void multi_tensor_sgd_cuda(
   // 2. fp32, fp32, fp32, No
   // 3. fp16, fp32, fp32, Yes
   // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
+  // 5. bfp16, bfp16, bfp16, No
+  // 6. bfp16, fp32, fp32, Yes
   // It's easier to hardcode these possibilities than to use
   // switches etc. to handle the cross-product of cases where
   // we don't want the majority of them.
@@ -268,6 +270,46 @@ void multi_tensor_sgd_cuda(
       wd_after_momentum,
       scale);
   }
+  // Case 5. bfp16, bfp16, bfp16, No
+  else if(grad_type == at::ScalarType::BFloat16 &&
+          weight_type == at::ScalarType::BFloat16 &&
+          num_tensors == 3)
+  {
+    multi_tensor_apply<3>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<3, at::BFloat16, at::BFloat16>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run,
+      wd_after_momentum,
+      scale);
+  }
+  // Case 6. bfp16, fp32, fp32, Yes
+  else if(grad_type == at::ScalarType::BFloat16 &&
+          weight_type == at::ScalarType::Float &&
+          num_tensors == 4)
+  {
+    multi_tensor_apply<4>(
+      BLOCK_SIZE,
+      chunk_size,
+      noop_flag,
+      tensor_lists,
+      SGDFunctor<4, at::BFloat16, float>(),
+      wd,
+      momentum,
+      dampening,
+      lr,
+      nesterov,
+      first_run,
+      wd_after_momentum,
+      scale);
+  }
   else
   {
     AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",