Commit 6763a8be authored Feb 18, 2019 by Michael Carilli

Reworked multi tensor apply, added tests

parent 889d1712
Showing 13 changed files with 423 additions and 204 deletions (+423, -204)
apex/amp/__init__.py                        +0    -1
apex/amp/_initialize.py                     +3    -3
apex/amp/frontend.py                        +6    -6
apex/amp/multi_tensor_apply.py              +0    -137
apex/amp/scaler.py                          +3    -3
apex/fp16_utils/fp16_optimizer.py           +9    -0
csrc/amp_C_frontend.cpp                     +40   -0
csrc/multi_tensor_apply.cuh                 +120  -0
csrc/multi_tensor_apply.h                   +0    -49
csrc/multi_tensor_scale_kernel.cu           +111  -0
examples/imagenet/main_fp16_optimizer.py    +1    -1
setup.py                                    +4    -4
tests/run_amp/test_multi_tensor_scale.py    +126  -0
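For orientation before the per-file diffs, here is a minimal sketch of how the reworked interface is driven from Python. It is pieced together from the new test file and the scaler changes below, not part of the commit itself, and it assumes an Apex build with the amp_C extension (installed with --cpp_ext --cuda_ext); the tensor sizes and the 4.0 scale factor are arbitrary.

# Hedged sketch (not part of the commit); mirrors the usage in tests/run_amp/test_multi_tensor_scale.py.
import torch
import amp_C
from apex.multi_tensor_apply import MultiTensorApply

applier = MultiTensorApply(2048*32)             # argument is the per-block chunk size
overflow_buf = torch.cuda.IntTensor(1).zero_()  # set to 1 by the kernel if any value is inf/nan

# Two fp16 inputs and two fp32 outputs of matching sizes.
ins  = [torch.cuda.HalfTensor(777).fill_(4.0), torch.cuda.HalfTensor(555).fill_(4.0)]
outs = [torch.cuda.FloatTensor(777), torch.cuda.FloatTensor(555)]

# One fused launch scales every tensor in `ins` by 1/4.0 and writes the results into `outs`.
applier(amp_C.multi_tensor_scale, overflow_buf, [ins, outs], 1./4.0)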
apex/amp/__init__.py

@@ -2,4 +2,3 @@ from .amp import init, half_function, float_function, promote_function,\
         register_half_function, register_float_function, register_promote_function
 from .handle import scale_loss
 from .frontend import register
-from .multi_tensor_apply import MultiTensorApply, multi_tensor_applier
apex/amp/initialize.py → apex/amp/_initialize.py

@@ -11,14 +11,14 @@ def check_params_fp32(model):
     for name, param in model.named_parameters():
         if param.is_floating_point() and param.type() != "torch.cuda.FloatTensor":
             print("Warning: Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
-                  "When using amp.register, you do not need to call .half() on your model\n"
+                  "When using amp.initialize, you do not need to call .half() on your model\n"
                   "before passing it, no matter what optimization level you choose.".format(
                   name, param.type()))
     for name, buf in model.named_buffers():
         if buf.is_floating_point() and buf.type() != "torch.cuda.FloatTensor":
             print("Warning: Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
-                  "When using amp.register, you do not need to call .half() on your model\n"
+                  "When using amp.initialize, you do not need to call .half() on your model\n"
                   "before passing it, no matter what optimization level you choose.".format(
                   name, buf.type()))

@@ -79,7 +79,7 @@ def _initialize(models, optimizers, properties):
     if parallel_type is not None:
         raise RuntimeError("Incoming model is an instance of {}. ".format(parallel_type) +
             "Parallel wrappers should only be applied AFTER the model(s) have been "
-            "returned from amp.register.")
+            "returned from amp.initialize.")

     for model in models:
         check_params_fp32(model)
apex/amp/frontend.py

 import torch
-from .initialize import _initialize
+from ._initialize import _initialize
 from ._amp_state import _amp_state

@@ -24,7 +24,7 @@ class Properties(object):
                       "enable_ddp_interop" : False}

     """
-    This function will allow updating several options at a time without routing through
+    This function allows updating several options at a time without routing through
     __setattr__ checks, to avoid "you can't get there from here" scenarios.
     """
     def update_options_dict(new_options):

@@ -97,7 +97,7 @@ class O2:
         properties.cast_torch_functions = False
         properties.cast_batchnorm = torch.float32
         properties.master_weights = True
-        properties.loss_scale = 128.0
+        properties.loss_scale = "dynamic"
         properties.flatten_model_params = False
         properties.flatten_master_params = False
         properties.fused_optimizer = False

@@ -160,20 +160,20 @@ def check_params_fp32(model):
     for name, param in model.named_parameters():
         if param.type() != "torch.cuda.FloatTensor":
             print("Warning: Found param {} with type {}, expected torch.cuda.FloatTensor.\n"
-                  "When using amp.register, you do not need to call .half() on your model\n"
+                  "When using amp.initialize, you do not need to call .half() on your model\n"
                   "before passing it, no matter what optimization level you choose.",
                   name, param.type())
     for name, param in model.named_buffers():
         if param.type() != "torch.cuda.FloatTensor":
             print("Warning: Found buffer {} with type {}, expected torch.cuda.FloatTensor.\n"
-                  "When using amp.register, you do not need to call .half() on your model\n"
+                  "When using amp.initialize, you do not need to call .half() on your model\n"
                   "before passing it, no matter what optimization level you choose.",
                   name, param.type())

 # allow user to directly pass Properties struct as well?
-def register(models, optimizers, enabled=True, opt_level=None, **kwargs):
+def initialize(models, optimizers, enabled=True, opt_level=None, **kwargs):
     """
     Expected kwargs:
     opt_level=None,
apex/amp/multi_tensor_apply.py  (deleted, 100644 → 0)

import torch

class MultiTensorApply(object):
    available = False
    warned = False

    def __init__(self, max_blocks, max_tensors, max_depth, chunk_size):
        try:
            import amp_C
            MultiTensorApply.available = True
            MultiTensorApply.prep_multi_tensor_launch = amp_C.prep_multi_tensor_launch
            self.chunk_size = chunk_size
            self.reallocate(max_blocks, max_tensors, max_depth)
        except ImportError as err:
            MultiTensorApply.availble = False
            MultiTensorApply.import_err = err

    def check_avail(self):
        if MultiTensorApply.available == False:
            raise RuntimeError(
                "Attempted to call MultiTensorApply method, but MultiTensorApply "
                "is not available, possibly because Apex was installed without "
                "--cpp_ext --cuda_ext. Original import error message:",
                MultiTensorApply.import_err)

    def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
        self.check_avail()

        assert len(tensor_lists) > 0, "len(tensor_lists) = {}".format(len(tensor_lists))
        len0 = len(tensor_lists[0])
        assert len0 > 0, "len(tensor_lists[0]) = {}".format(len0)
        for i, l in enumerate(tensor_lists):
            assert len(tensor_lists[i]) == len0,\
                "len(tensor_lists[{}] = {}, len(tensor_lists[0] = {}".format(
                len(tensor_lists[i]), len(tensor_lists[0]))

        self.assign_blocks(tensor_lists)

        # print(self.gpu_block_to_tensor)
        # print(self.gpu_block_to_chunk)
        # print(self.gpu_tensor_sizes)

        return op(self.nblocks,
                  noop_flag_buffer,
                  self.cpu_tensor_addresses,
                  self.gpu_block_to_tensor,
                  self.gpu_block_to_chunk,
                  self.gpu_tensor_sizes,
                  self.gpu_tensor_addresses,
                  self.chunk_size,
                  tensor_lists,
                  *args)

        # print()
        # print([[p.data_ptr() for p in l] for l in tensor_lists])
        # print()
        # print(self.gpu_tensor_addresses)

    def assign_blocks(self, tensor_lists):
        self.check_avail()

        needs_reallocate = False
        # Currently, this loop appears prohibitively expensive.
        # Need to move to c++.
        torch.cuda.nvtx.range_push("assign_blocks loop")
        # list0 = tensor_lists[0]
        # self.nblocks = 0
        # for t, tensor in enumerate(list0):
        #     blocks_this_tensor = (tensor.numel() +
        #                           self.chunk_size - 1)//self.chunk_size
        #     if not needs_reallocate:
        #         self.cpu_tensor_sizes[t] = tensor.numel()
        #     for chunk in range(blocks_this_tensor):
        #         if self.nblocks >= self.max_blocks:
        #             needs_reallocate = True
        #         if not needs_reallocate:
        #             self.cpu_block_to_tensor[self.nblocks] = t
        #             self.cpu_block_to_chunk[self.nblocks] = chunk
        #         self.nblocks += 1
        needs_reallocate, self.nblocks = MultiTensorApply.prep_multi_tensor_launch(
            self.cpu_block_to_tensor,
            self.cpu_block_to_chunk,
            self.cpu_tensor_sizes,
            self.gpu_block_to_tensor,
            self.gpu_block_to_chunk,
            self.gpu_tensor_sizes,
            self.chunk_size,
            self.max_depth,
            self.max_tensors,
            self.max_blocks,
            tensor_lists)
        torch.cuda.nvtx.range_pop()

        # print(self.nblocks)

        if self.nblocks > self.max_blocks:
            self.max_blocks = self.nblocks
        if len(tensor_lists) > self.max_depth:
            self.max_depth = len(tensor_lists)
        if len(tensor_lists[0]) > self.max_tensors:
            self.max_tensors = len(tensor_lists[0])

        if needs_reallocate:
            self.reallocate(self.max_blocks, self.max_tensors, self.max_depth)
            needs_reallocate, self.nblocks = MultiTensorApply.prep_multi_tensor_launch(
                self.cpu_block_to_tensor,
                self.cpu_block_to_chunk,
                self.cpu_tensor_sizes,
                self.gpu_block_to_tensor,
                self.gpu_block_to_chunk,
                self.gpu_tensor_sizes,
                self.chunk_size,
                self.max_depth,
                self.max_tensors,
                self.max_blocks,
                tensor_lists)
            assert needs_reallocate == 0, "Should not need reallocate on second attempt."
            assert self.nblocks <= self.max_blocks, "Should not need to increase blocks again."

    def reallocate(self, max_blocks, max_tensors, max_depth):
        self.check_avail()

        self.max_blocks = max_blocks
        self.max_tensors = max_tensors
        self.max_depth = max_depth
        self.cpu_block_to_tensor = torch.IntTensor(max_blocks).pin_memory()
        self.cpu_block_to_chunk = torch.IntTensor(max_blocks).pin_memory()
        self.cpu_tensor_sizes = torch.IntTensor(max_tensors).pin_memory()
        self.cpu_tensor_addresses = torch.LongTensor(max_depth, max_tensors).pin_memory()
        self.gpu_block_to_tensor = torch.cuda.IntTensor(max_blocks)
        self.gpu_block_to_chunk = torch.cuda.IntTensor(max_blocks)
        self.gpu_tensor_sizes = torch.cuda.IntTensor(max_tensors)
        self.gpu_tensor_addresses = torch.cuda.LongTensor(max_depth, max_tensors)

multi_tensor_applier = MultiTensorApply(1000, 100, 4, 2048)
apex/amp/scaler.py

 import torch
 import logging
-from .multi_tensor_apply import multi_tensor_applier
+from ..multi_tensor_apply import multi_tensor_applier
 from ._amp_state import _amp_state
 # from apex_C import scale_check_overflow

@@ -46,7 +46,7 @@ class LossScaler(object):
         if multi_tensor_applier.available:
             import amp_C
             LossScaler.has_fused_kernel = multi_tensor_applier.available
-            LossScaler.multi_tensor_unscale_cuda = amp_C.multi_tensor_unscale
+            LossScaler.multi_tensor_scale_cuda = amp_C.multi_tensor_scale
         else:
             if not LossScaler.warned_no_fused_kernel:
                 print("Warning: multi_tensor_applier fused downscale kernel is unavailable, "

@@ -115,7 +115,7 @@ class LossScaler(object):
             return
         else:
-            multi_tensor_applier(LossScaler.multi_tensor_unscale_cuda,
+            multi_tensor_applier(LossScaler.multi_tensor_scale_cuda,
                                  self._overflow_buf,
                                  [model_grads, master_grads],
                                  1./scale)
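The unscale path above now routes all gradients through a single fused launch instead of one kernel per tensor. A hedged sketch of that call pattern follows; model_params, master_params and loss_scale are placeholders, and only the multi_tensor_applier call itself mirrors the hunk above.

# Sketch of the fused unscale call; model_params, master_params and loss_scale are placeholders.
import torch
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

overflow_buf = torch.cuda.IntTensor(1).zero_()
model_grads  = [p.grad.data for p in model_params  if p.grad is not None]  # fp16 model grads
master_grads = [p.grad.data for p in master_params if p.grad is not None]  # fp32 master grads

# Unscale every model grad into the corresponding master grad in one launch;
# overflow_buf becomes nonzero if any gradient contained inf/nan.
multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf,
                     [model_grads, master_grads], 1./loss_scale)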
apex/fp16_utils/fp16_optimizer.py

@@ -515,15 +515,24 @@ class FP16_Optimizer(object):
        # self._downscale_master()
        # Use the one-shot multi-tensor apply kernel
        if len(self.all_fp16_params) > 0:
            # print("Model grads before")
            # print([param.grad.data for param in self.all_fp16_params])
            self.loss_scaler.unscale(
                self.all_fp16_params,
                self.all_fp32_from_fp16_params,
                self.loss_scaler.loss_scale())
            # print("Master grads after")
            # print([param.grad.data for param in self.all_fp32_from_fp16_params])
        if len(self.all_fp32_from_fp32_params) > 0:
            # print("Model grads before")
            # print([param.grad.data for param in self.all_fp32_from_fp32_params])
            self.loss_scaler.unscale(
                self.all_fp32_from_fp32_params,
                self.all_fp32_from_fp32_params,
                self.loss_scaler.loss_scale())
            # print("Master grads after")
            # print([param.grad.data for param in self.all_fp32_from_fp32_params])
        # quit()
        self.overflow = self.loss_scaler.update_scale()
csrc/scale_check_overflow.cpp → csrc/amp_C_frontend.cpp

 #include <torch/extension.h>

-void multi_tensor_unscale_cuda(
-  int nblocks,
-  at::Tensor noop_flag,
-  at::Tensor cpu_tensor_addresses,
-  at::Tensor gpu_block_to_tensor,
-  at::Tensor gpu_block_to_chunk,
-  at::Tensor gpu_tensor_sizes,
-  at::Tensor gpu_tensor_addresses,
+void multi_tensor_scale_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists,
   float scale);

-std::vector<int> prep_multi_tensor_launch(
-  at::Tensor cpu_block_to_tensor,
-  at::Tensor cpu_block_to_chunk,
-  at::Tensor cpu_tensor_sizes,
-  at::Tensor gpu_block_to_tensor,
-  at::Tensor gpu_block_to_chunk,
-  at::Tensor gpu_tensor_sizes,
-  int chunk_size,
-  int max_depth,
-  int max_tensors,
-  int max_blocks,
-  std::vector<std::vector<at::Tensor>> tensor_lists)
-{
-  int needs_reallocate = 0;
-  if(tensor_lists.size() > max_depth || tensor_lists[0].size() > max_tensors)
-    needs_reallocate = 1;
-
-  auto cpu_tensor_sizes_a = cpu_tensor_sizes.accessor<int,1>();
-  auto cpu_block_to_tensor_a = cpu_block_to_tensor.accessor<int,1>();
-  auto cpu_block_to_chunk_a = cpu_block_to_chunk.accessor<int,1>();
-
-  int nblocks = 0;
-  for(int t = 0; t < tensor_lists[0].size(); t++)
-  {
-    int blocks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;
-    if(!needs_reallocate)
-      cpu_tensor_sizes_a[t] = tensor_lists[0][t].numel();
-    for(int chunk = 0; chunk < blocks_this_tensor; chunk++)
-    {
-      if(nblocks >= max_blocks)
-        needs_reallocate = 1;
-      if(!needs_reallocate)
-      {
-        cpu_block_to_tensor_a[nblocks] = t;
-        cpu_block_to_chunk_a[nblocks] = chunk;
-      }
-      nblocks++;
-    }
-  }
-
-  if(!needs_reallocate)
-  {
-    gpu_block_to_tensor.copy_(cpu_block_to_tensor, 1);
-    gpu_block_to_chunk.copy_(cpu_block_to_chunk, 1);
-    gpu_tensor_sizes.copy_(cpu_tensor_sizes, 1);
-  }
-
-  return std::vector<int>{needs_reallocate, nblocks};
-}
-
 void scale_check_overflow_cuda(const at::Tensor& grads,
                                float scale,
                                const at::Tensor& d_buf,
                                const at::Tensor& downscaled_grads);

 void scale_check_overflow(at::Tensor grads,
                           float scale,
                           at::Tensor overflow_buf,
                           at::Tensor downscaled_grads)
                           // const at::optional<at::Tensor> downscaled_grads)
 {
   AT_CHECK(grads.type().is_cuda(), "grads must be a CUDA tensor");
   AT_CHECK(grads.is_contiguous(), "grads must be contiguous");

@@ -90,7 +35,6 @@ void scale_check_overflow(at::Tensor grads,
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("scale_check_overflow", &scale_check_overflow,
         "Fused overflow check + scale for FP32 tensors");
-  m.def("prep_multi_tensor_launch", &prep_multi_tensor_launch,
-        "Prepare multitensor launch");
-  m.def("multi_tensor_unscale", &multi_tensor_unscale_cuda,
-        "Fused overflow check + unscale for a list of contiguous tensors");
+  m.def("multi_tensor_scale", &multi_tensor_scale_cuda,
+        "Fused overflow check + scale for a list of contiguous tensors");
 }
csrc/multi_tensor_apply.cuh  (new file, 0 → 100644)

#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <cuda_runtime.h>
// #include <iostream>

// This header is the one-stop shop for all your multi-tensor apply needs.

constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};

template<int n> struct TensorList
{
  void* addresses[n][depth_to_max_tensors[n-1]];
  int sizes[depth_to_max_tensors[n-1]];
  int block_to_tensor[depth_to_max_blocks[n-1]];
  int block_to_chunk[depth_to_max_blocks[n-1]];
};

template<typename T, typename U, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(
  int chunk_size,
  volatile int* noop_flag,
  T tl,
  U callable,
  ArgTypes... args) // in_t** in, float** out, float scale
{
  // Hand the chunk information to the user-supplied functor to process however it likes.
  callable(chunk_size, noop_flag, tl, args...);
}

template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply(
  int block_size,
  int chunk_size,
  const at::Tensor& noop_flag,
  const std::vector<std::vector<at::Tensor>>& tensor_lists,
  T callable,
  ArgTypes... args)
{
  AT_CHECK(tensor_lists.size() > 0, "tensor_lists.size() is not > 0");
  int len0 = tensor_lists[0].size();
  AT_CHECK(len0 > 0, "tensor_lists[0].size() is not > 0");
  for(int l = 0; l < tensor_lists.size(); l++) // No range-based for because I need indices
  {
    AT_CHECK(tensor_lists[l].size() == len0, "Size mismatch among tensor lists");
    for(int t = 0; t < tensor_lists[l].size(); t++)
    {
      AT_CHECK(tensor_lists[l][t].is_contiguous(), "A tensor was not contiguous.");
      AT_CHECK(tensor_lists[l][t].is_cuda(), "A tensor was not cuda.");
      AT_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(), "Size mismatch");
    }
  }

  int ntensors = tensor_lists[0].size();

  TensorList<depth> tl;

  auto stream = at::cuda::getCurrentCUDAStream();

  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for(int t = 0; t < ntensors; t++)
  {
    tl.sizes[loc_tensor_info] = tensor_lists[0][t].numel();
    for(int d = 0; d < depth; d++)
      tl.addresses[d][loc_tensor_info] = tensor_lists[d][t].data_ptr();
    loc_tensor_info++;

    int chunks_this_tensor = (tensor_lists[0][t].numel() + chunk_size - 1)/chunk_size;

    for(int chunk = 0; chunk < chunks_this_tensor; chunk++)
    {
      // std::cout << chunks_this_tensor << std::endl;
      tl.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tl.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      bool tensors_full = (loc_tensor_info == depth_to_max_tensors[depth-1] &&
                           chunk == chunks_this_tensor - 1);
      bool blocks_full = (loc_block_info == depth_to_max_blocks[depth-1]);
      bool last_chunk = (t == ntensors - 1 && chunk == chunks_this_tensor - 1);
      if(tensors_full || blocks_full || last_chunk)
      {
        // using accscalar_t = acc_type<scalar_t, true>;
        multi_tensor_apply_kernel<<<loc_block_info, block_size, 0, stream>>>(
          chunk_size,
          noop_flag.data<int>(),
          tl,
          callable,
          args...);

        AT_CUDA_CHECK(cudaGetLastError());

        // Reset.  The control flow possibilities here make my brain hurt.
        loc_block_info = 0;
        if(chunk == chunks_this_tensor - 1)
        {
          // std::cout << "Hit case 1 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
          loc_tensor_info = 0;
        }
        else
        {
          // std::cout << "Hit case 2 " << cond1 << " " << cond2 << " " << cond3 << std::endl;
          for(int d = 0; d < depth; d++)
            tl.addresses[d][0] = tl.addresses[d][loc_tensor_info-1];
          tl.sizes[0] = tl.sizes[loc_tensor_info-1];
          loc_tensor_info = 1;
        }
      }
    }
  }
}
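To make the launch batching in this header concrete, here is a small back-of-the-envelope sketch in plain Python (not part of the commit) of how many chunks and launches the depth-2 path would produce for one of the size pairs used in the new test; it ignores the tensor-slot limit (64 tensors at depth 2), which does not bind in this example.

# Rough model of the chunking arithmetic in multi_tensor_apply.cuh for depth = 2.
chunk_size = 2048*32           # chunk size used by one of the appliers in the new test
max_blocks = 320               # depth_to_max_blocks[1]
numels = [7777*77, 555*555]    # one input_size_pair from the test

chunks = [(n + chunk_size - 1)//chunk_size for n in numels]    # chunks per tensor
total_blocks = sum(chunks)                                     # one CUDA block per chunk
launches = (total_blocks + max_blocks - 1)//max_blocks         # a launch is flushed when 320 blocks fill up
print(chunks, total_blocks, launches)                          # [10, 5] 15 1 -> a single kernel launch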
csrc/multi_tensor_apply.h  (deleted, 100644 → 0)

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <assert.h>
#include <cuda_runtime.h>

template<typename T, typename... ArgTypes>
__global__ void multi_tensor_apply_kernel(
  volatile int* noop_flag,
  int* block_to_tensor,
  int* block_to_chunk, // could also get this from scan
  int* tensor_sizes,
  int chunk_size,
  void** addresses,
  int addresses_x,
  T callable,
  ArgTypes... args) // in_t** in, float** out, float scale
{
  __shared__ int noop;
  __shared__ int chunk_idx;
  __shared__ int tensor_idx;
  __shared__ int n;

  if(threadIdx.x == 0)
  {
    noop = *noop_flag;
    tensor_idx = block_to_tensor[blockIdx.x];
    chunk_idx = block_to_chunk[blockIdx.x];
    n = tensor_sizes[tensor_idx];
  }

  __syncthreads();

  if(noop == 1)
    return;

  // Hand the chunk information to the user-supplied functor to process however it likes.
  callable(noop_flag, tensor_idx, chunk_idx, chunk_size, n, addresses, addresses_x, args...);
}
csrc/multi_tensor_unscale_kernel.cu → csrc/multi_tensor_scale_kernel.cu

@@ -2,33 +2,39 @@
 #include <ATen/AccumulateType.h>
 #include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/Exceptions.h>
-#include "multi_tensor_apply.h"
+#include "multi_tensor_apply.cuh"
 #include <assert.h>
 #include <cuda_runtime.h>

-#define BLOCK_SIZE 256
+#define BLOCK_SIZE 512
 #define ILP 4

 template<typename in_t>
-struct UnscaleFunctor
+struct ScaleFunctor
 {
-  __device__ __forceinline__ void operator()(volatile int* noop_flag,
-                                             int tensor_idx,
-                                             int chunk_idx,
-                                             int chunk_size,
-                                             int n,
-                                             void** addresses,
-                                             int addresses_x,
+  __device__ __forceinline__ void operator()(int chunk_size,
+                                             volatile int* noop_gmem,
+                                             TensorList<2>& tl,
                                              float scale)
   {
-    __shared__ int noop;
+    __shared__ int noop_smem;

-    in_t* in = (in_t*)addresses[tensor_idx];
+    if(threadIdx.x == 0)
+      noop_smem = *noop_gmem;
+    __syncthreads();
+    if(noop_smem == 1)
+      return;
+
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
+
+    in_t* in = (in_t*)tl.addresses[0][tensor_loc];
     in += chunk_idx*chunk_size;
-    float* out = (float*)addresses[addresses_x + tensor_idx];
+    float* out = (float*)tl.addresses[1][tensor_loc];
     out += chunk_idx*chunk_size;

     n -= chunk_idx*chunk_size;

@@ -39,14 +45,6 @@ struct UnscaleFunctor
         i_start < n && i_start < chunk_size;
         i_start += blockDim.x*ILP)
     {
-      if(threadIdx.x == 0)
-        noop = *noop_flag;
-      __syncthreads();
-      if(noop == 1)
-        break;
-
       #pragma unroll
       for(int ii = 0; ii < ILP; ii++)
       {

@@ -56,6 +54,11 @@ struct UnscaleFunctor
           incoming_vals[ii] = static_cast<float>(in[i]);
       }

+      // note for clarification to future michael:
+      // From a pure memory dependency perspective, there's likely no point unrolling
+      // the write loop, since writes just fire off once their LDGs arrive.
+      // Put another way, the STGs are dependent on the LDGs, but not on each other.
+      // There is still compute ILP benefit from unrolling the loop though.
       #pragma unroll
       for(int ii = 0; ii < ILP; ii++)
       {

@@ -64,73 +67,45 @@ struct UnscaleFunctor
         if(isfinite(incoming_vals[ii]))
           out[i] = incoming_vals[ii]*scale;
         else
-          *noop_flag = 1; // Blindly fire off a write.  These will race but that's ok.
-      }
-      // This is NOT guaranteed to be seen immediately by thread 0 on the next iteration.
-    }
-    // I wonder if there's a way we can rig the short-circuiting with only one syncthreads.
-  }
-  // It's possible we can just lean on the cache (no smem or syncs) and still be fast.
-};
+          *noop_gmem = 1; // Blindly fire off a write.  These will race but that's ok.
+      }
+      // *noop_gmem = 1 is NOT guaranteed to be seen immediately by thread 0.  I wonder if
+      // we can rig block-wide and grid-wide short-circuiting with only one syncthreads.
+      // It's possible we can just lean on the cache (no smem or syncs) and still be fast.
+      if(threadIdx.x == 0)
+        noop_smem = *noop_gmem;
+      __syncthreads();
+      if(noop_smem == 1)
+        break;
+    }
+  }
+};

-void multi_tensor_unscale_cuda(
-  int nblocks,
-  at::Tensor noop_flag,
-  at::Tensor cpu_tensor_addresses,
-  at::Tensor gpu_block_to_tensor,
-  at::Tensor gpu_block_to_chunk,
-  at::Tensor gpu_tensor_sizes,
-  at::Tensor gpu_tensor_addresses,
+void multi_tensor_scale_cuda(
+  int chunk_size,
+  at::Tensor noop_flag,
+  std::vector<std::vector<at::Tensor>> tensor_lists,
   float scale)
 {
   using namespace at;
-
-  AT_CHECK(nblocks > 0, "nblocks is not > 0");
-
-  int addresses_x = gpu_tensor_addresses.size(1);
-
-  // <.< >.>  i don't see any cops.  i'm going to access the pointers directly.
-  // auto addresses_a = cpu_tensor_addresses.accessor<int64_t, 2>();
-  // This logic could be moved to prep_multi_tensor_launch, but we might need to
-  // pick which kernel instantiation to launch based on the RTTI of tensor_lists,
-  // so we may as well accept tensor_lists and extract the pointers here.
-  void** addresses_a = (void**)cpu_tensor_addresses.data_ptr();
-
-  int len0 = tensor_lists[0].size();
-  for(unsigned int l = 0; l < tensor_lists.size(); l++)
-  {
-    AT_CHECK(tensor_lists[l].size() == len0, "Lengths of tensor lists do not match.");
-    for(unsigned int t = 0; t < tensor_lists[l].size(); t++)
-    {
-      AT_CHECK(tensor_lists[l][t].numel() == tensor_lists[0][t].numel(),
-               "Numel mismatch in corresponding tensors in different lists.");
-      addresses_a[l*addresses_x + t] = tensor_lists[l][t].data_ptr();
-      // addresses_a[l][t] = (void*)tensor_lists[l][t].data<float>();
-    }
-  }
-
-  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  gpu_tensor_addresses.copy_(cpu_tensor_addresses, 1 /*non_blocking*/);
-
-  // Lock the output (downscaled) type to float.
+  // The output (downscaled) type is always float.
+  // If build times suffer, think about where to put this dispatch,
+  // and what logic should be moved out of multi_tensor_apply.
   AT_DISPATCH_FLOATING_TYPES_AND_HALF(tensor_lists[0][0].type(),
-    "multi_tensor_unscale_cuda",
+    "multi_tensor_scale_cuda",
     [&]
     {
       // using accscalar_t = acc_type<scalar_t, true>;
-      multi_tensor_apply_kernel<<<nblocks, BLOCK_SIZE, 0, stream>>>(
-        noop_flag.data<int>(),
-        gpu_block_to_tensor.data<int>(),
-        gpu_block_to_chunk.data<int>(),
-        gpu_tensor_sizes.data<int>(),
-        chunk_size,
-        (void**)gpu_tensor_addresses.data_ptr(),
-        addresses_x,
-        UnscaleFunctor<scalar_t>(),
+      multi_tensor_apply<2>(
+        BLOCK_SIZE,
+        chunk_size,
+        noop_flag,
+        tensor_lists,
+        ScaleFunctor<scalar_t>(),
         scale);
     });

   AT_CUDA_CHECK(cudaGetLastError());
   // AT_CUDA_CHECK(cudaDeviceSynchronize());
 }
examples/imagenet/main_fp16_optimizer.py

@@ -139,7 +139,7 @@ def main():
     model = model.cuda()
     if args.fp16:
-        model = network_to_half(model)
+        model = FP16Model(model)
     if args.distributed:
         # By default, apex.parallel.DistributedDataParallel overlaps communication with
         # computation in the backward pass.
setup.py

@@ -38,12 +38,12 @@ if "--cuda_ext" in sys.argv:
     else:
         ext_modules.append(
             CUDAExtension(name='amp_C',
-                          sources=['csrc/scale_check_overflow.cpp',
+                          sources=['csrc/amp_C_frontend.cpp',
                                    'csrc/scale_check_overflow_kernel.cu',
-                                   'csrc/multi_tensor_unscale_kernel.cu'],
-                          extra_compile_args={'cxx': ['-O3',],
+                                   'csrc/multi_tensor_scale_kernel.cu'],
+                          extra_compile_args={'cxx': ['-O3'],
                                               'nvcc':['-lineinfo',
                                                       '-O3',
-                                                      '-O3',
                                                       '--use_fast_math']}))
         ext_modules.append(
             CUDAExtension(name='fused_adam_cuda',
tests/run_amp/test_multi_tensor_scale.py  (new file, 0 → 100644)

import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

try:
    import amp_C
    from amp_C import multi_tensor_scale
    from apex.multi_tensor_apply import MultiTensorApply
    disabled = False
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestMultiTensorApply.  ImportError was ", err)
    disabled = True


class TestMultiTensorScale(unittest.TestCase):

    def setUp(self):
        self.scale = 4.0
        self.overflow_buf = torch.cuda.IntTensor(1).zero_()
        self.ref = torch.cuda.FloatTensor([1.0])

        common_init(self)

    def tearDown(self):
        pass

    # The tensor creation here is written for convenience, not speed.
    def downscale(self, sizea, sizeb, applier, repeat_tensors, in_type, inplace=False):
        self.overflow_buf.zero_()
        a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
        b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)

        out_list = []
        for i in range(repeat_tensors):
            out_list += [a.clone(), b.clone()]

        if inplace:
            in_list = out_list
        else:
            in_list = [out.clone().to(in_type) for out in out_list]

        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)

        self.assertTrue(all([torch.allclose(out, self.ref) for out in out_list]))
        self.assertTrue(self.overflow_buf.item() == 0)

    def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, t, ind, val, inplace=False):
        self.overflow_buf.zero_()
        a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
        b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)

        out_list = []
        for i in range(repeat_tensors):
            out_list += [a.clone(), b.clone()]

        if inplace:
            in_list = out_list
        else:
            in_list = [out.clone().to(in_type) for out in out_list]

        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)

        self.overflow_buf.zero_()
        in_list[t][ind] = val
        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
        self.assertTrue(self.overflow_buf.item())

    # Currently, the fused kernel gives a hard error if you attempt to downscale
    # into fp16 output, which imo is the desired behavior.  Maybe someday we
    # will learn otherwise.
    # @unittest.skipIf(disabled, "amp_C is unavailable")
    # def test_fp16_to_fp16(self):
    #     self.downscale(self.fp16, self.fp16, self.fp16_ref)
    #
    # @unittest.skipIf(disabled, "amp_C is unavailable")
    # def test_fp32_to_fp16(self):
    #     self.downscale(self.fp32, self.fp16, self.fp16_ref)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        input_size_pairs = (
            (7777*77, 555*555),
            (777, 555),
            (555, 2048*32+1),
            (2048*32+1, 555),
            (555, 2048*32),
            (2048*32, 555),
            (33333, 555),
            (555, 33333))
        appliers = (
            MultiTensorApply(2048*32),
            MultiTensorApply(333),
            MultiTensorApply(33333))
        repeat_tensors = (
            1,
            55)
        dtype_inplace_pairs = (
            (torch.float16, False),
            (torch.float32, False),
            (torch.float32, True))

        for sizea, sizeb in input_size_pairs:
            for applier in appliers:
                for repeat in repeat_tensors:
                    for dtype, inplace in dtype_inplace_pairs:
                        self.downscale(sizea, sizeb, applier, repeat, dtype, inplace=inplace)
                        self.find_inf(sizea, sizeb, applier, repeat, dtype,
                                      0, 0, float('nan'), inplace=inplace)
                        self.find_inf(sizea, sizeb, applier, repeat, dtype,
                                      2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
                        self.find_inf(sizea, sizeb, applier, repeat, dtype,
                                      2*(repeat//2), sizea//2, float('inf'), inplace=inplace)


if __name__ == '__main__':
    unittest.main()