Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
apex
Commits
d900e93c
Commit
d900e93c
authored
Apr 26, 2019
by
Michael Carilli
Browse files
Merging in master
parents
c978bda5
855808f3
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
235 additions
and
213 deletions
+235
-213
apex/amp/_initialize.py
apex/amp/_initialize.py
+3
-3
csrc/fused_adam_cuda_kernel.cu
csrc/fused_adam_cuda_kernel.cu
+15
-15
csrc/layer_norm_cuda.cpp
csrc/layer_norm_cuda.cpp
+2
-2
csrc/layer_norm_cuda_kernel.cu
csrc/layer_norm_cuda_kernel.cu
+17
-17
csrc/multi_tensor_l2norm_kernel.cu
csrc/multi_tensor_l2norm_kernel.cu
+1
-1
csrc/multi_tensor_scale_kernel.cu
csrc/multi_tensor_scale_kernel.cu
+9
-33
csrc/type_shim.h
csrc/type_shim.h
+55
-9
csrc/welford.cu
csrc/welford.cu
+133
-133
No files found.
apex/amp/_initialize.py
View file @
d900e93c
...
@@ -111,7 +111,7 @@ def check_optimizers(optimizers):
...
@@ -111,7 +111,7 @@ def check_optimizers(optimizers):
if
isinstance
(
optim
,
FP16_Optimizer_for_fused
):
if
isinstance
(
optim
,
FP16_Optimizer_for_fused
):
bad_optim_type
=
"apex.optimizers.FP16_Optimizer"
bad_optim_type
=
"apex.optimizers.FP16_Optimizer"
if
bad_optim_type
is
not
None
:
if
bad_optim_type
is
not
None
:
raise
RuntimeError
(
"An incoming optimizer is an instance of {}. "
.
format
(
optim_type
)
+
raise
RuntimeError
(
"An incoming optimizer is an instance of {}. "
.
format
(
bad_
optim_type
)
+
"The optimizer(s) passed to amp.initialize() must be bare
\n
"
"The optimizer(s) passed to amp.initialize() must be bare
\n
"
"instances of either ordinary Pytorch optimizers, or Apex fused
\n
"
"instances of either ordinary Pytorch optimizers, or Apex fused
\n
"
"optimizers (FusedAdam or FusedSGD).
\n
"
"optimizers (FusedAdam or FusedSGD).
\n
"
...
@@ -132,7 +132,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
...
@@ -132,7 +132,9 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
optimizers
=
[]
optimizers
=
[]
elif
isinstance
(
optimizers
,
list
):
elif
isinstance
(
optimizers
,
list
):
optimizers_was_list
=
True
optimizers_was_list
=
True
check_optimizers
(
optimizers
)
else
:
else
:
check_optimizers
([
optimizers
])
raise
TypeError
(
"optimizers must be either a single optimizer or a list of optimizers."
)
raise
TypeError
(
"optimizers must be either a single optimizer or a list of optimizers."
)
if
isinstance
(
models
,
torch
.
nn
.
Module
):
if
isinstance
(
models
,
torch
.
nn
.
Module
):
...
@@ -148,8 +150,6 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
...
@@ -148,8 +150,6 @@ def _initialize(models, optimizers, properties, num_losses=1, cast_model_outputs
if
not
_amp_state
.
allow_incoming_model_not_fp32
:
if
not
_amp_state
.
allow_incoming_model_not_fp32
:
check_params_fp32
(
models
)
check_params_fp32
(
models
)
check_optimizers
(
optimizers
)
# In the future, when FP16_Optimizer can be deprecated and master weights can
# In the future, when FP16_Optimizer can be deprecated and master weights can
# become an attribute, remember to stash master weights before casting the model.
# become an attribute, remember to stash master weights before casting the model.
...
...
csrc/fused_adam_cuda_kernel.cu
View file @
d900e93c
...
@@ -182,19 +182,19 @@ void fused_adam_cuda(
...
@@ -182,19 +182,19 @@ void fused_adam_cuda(
}
}
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
g
.
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Half
)
{
if
(
g
.
scalar
_t
ype
()
==
at
::
ScalarType
::
Half
)
{
//all other values should be fp32 for half gradients
//all other values should be fp32 for half gradients
AT_ASSERTM
(
p
.
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Float
,
"expected parameter to be of float type"
);
AT_ASSERTM
(
p
.
scalar
_t
ype
()
==
at
::
ScalarType
::
Float
,
"expected parameter to be of float type"
);
//dispatch is done on the gradient type
//dispatch is done on the gradient type
using
namespace
at
;
// prevents "toString is undefined" errors
using
namespace
at
;
// prevents "toString is undefined" errors
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
g
.
type
(),
"adam_cuda_kernel"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
g
.
scalar_
type
(),
0
,
"adam_cuda_kernel"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
adam_cuda_kernel
<
accscalar_t
,
scalar_t
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
adam_cuda_kernel
<
accscalar_t
,
scalar_t
_0
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
data
<
accscalar_t
>
(),
p
.
data
<
accscalar_t
>
(),
p_copy
.
numel
()
?
p_copy
.
data
<
scalar_t
>
()
:
NULL
,
p_copy
.
numel
()
?
p_copy
.
data
<
scalar_t
_0
>
()
:
NULL
,
m
.
data
<
accscalar_t
>
(),
m
.
data
<
accscalar_t
>
(),
v
.
data
<
accscalar_t
>
(),
v
.
data
<
accscalar_t
>
(),
g
.
data
<
scalar_t
>
(),
g
.
data
<
scalar_t
_0
>
(),
beta1
,
beta1
,
beta2
,
beta2
,
eps
,
eps
,
...
@@ -203,16 +203,16 @@ void fused_adam_cuda(
...
@@ -203,16 +203,16 @@ void fused_adam_cuda(
tsize
,
tsize
,
(
adamMode_t
)
mode
,
(
adamMode_t
)
mode
,
decay
);
decay
);
}));
)
}
else
{
}
else
{
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_
FLOATING_TYPES
(
g
.
type
(),
"adam_cuda_kernel"
,
([
&
]
{
DISPATCH_
DOUBLE_AND_FLOAT
(
g
.
scalar_
type
(),
0
,
"adam_cuda_kernel"
,
adam_cuda_kernel
<
scalar_t
,
scalar_t
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
adam_cuda_kernel
<
scalar_t
_0
,
scalar_t
_0
><<<
blocks
,
threadsPerBlock
,
0
,
stream
>>>
(
p
.
data
<
scalar_t
>
(),
p
.
data
<
scalar_t
_0
>
(),
NULL
,
//don't output p_copy for fp32, it's wasted write
NULL
,
//don't output p_copy for fp32, it's wasted write
m
.
data
<
scalar_t
>
(),
m
.
data
<
scalar_t
_0
>
(),
v
.
data
<
scalar_t
>
(),
v
.
data
<
scalar_t
_0
>
(),
g
.
data
<
scalar_t
>
(),
g
.
data
<
scalar_t
_0
>
(),
beta1
,
beta1
,
beta2
,
beta2
,
eps
,
eps
,
...
@@ -221,7 +221,7 @@ void fused_adam_cuda(
...
@@ -221,7 +221,7 @@ void fused_adam_cuda(
tsize
,
tsize
,
(
adamMode_t
)
mode
,
(
adamMode_t
)
mode
,
decay
);
decay
);
})
);
);
}
}
THCudaCheck
(
cudaGetLastError
());
THCudaCheck
(
cudaGetLastError
());
...
...
csrc/layer_norm_cuda.cpp
View file @
d900e93c
...
@@ -129,7 +129,7 @@ std::vector<at::Tensor> layer_norm(
...
@@ -129,7 +129,7 @@ std::vector<at::Tensor> layer_norm(
int
n1
,
n2
;
int
n1
,
n2
;
check_args
(
input
,
normalized_shape
,
n1
,
n2
);
check_args
(
input
,
normalized_shape
,
n1
,
n2
);
at
::
Tensor
output
=
at
::
empty_like
(
input
);
at
::
Tensor
output
=
at
::
empty_like
(
input
);
at
::
Tensor
mean
=
at
::
empty
({
n1
},
input
.
options
().
dtype
(
input
.
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
input
.
type
().
scalar
T
ype
()));
at
::
Tensor
mean
=
at
::
empty
({
n1
},
input
.
options
().
dtype
(
input
.
scalar
_t
ype
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
input
.
scalar
_t
ype
()));
at
::
Tensor
invvar
=
at
::
empty_like
(
mean
);
at
::
Tensor
invvar
=
at
::
empty_like
(
mean
);
cuda_layer_norm
(
&
output
,
&
mean
,
&
invvar
,
&
input
,
n1
,
n2
,
cuda_layer_norm
(
&
output
,
&
mean
,
&
invvar
,
&
input
,
n1
,
n2
,
normalized_shape
,
NULL
,
NULL
,
epsilon
);
normalized_shape
,
NULL
,
NULL
,
epsilon
);
...
@@ -151,7 +151,7 @@ std::vector<at::Tensor> layer_norm_affine(
...
@@ -151,7 +151,7 @@ std::vector<at::Tensor> layer_norm_affine(
int
n1
,
n2
;
int
n1
,
n2
;
check_args
(
input
,
normalized_shape
,
gamma
,
beta
,
n1
,
n2
);
check_args
(
input
,
normalized_shape
,
gamma
,
beta
,
n1
,
n2
);
at
::
Tensor
output
=
at
::
empty_like
(
input
);
at
::
Tensor
output
=
at
::
empty_like
(
input
);
at
::
Tensor
mean
=
at
::
empty
({
n1
},
input
.
options
().
dtype
(
input
.
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
input
.
type
().
scalar
T
ype
()));
at
::
Tensor
mean
=
at
::
empty
({
n1
},
input
.
options
().
dtype
(
input
.
scalar
_t
ype
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
input
.
scalar
_t
ype
()));
at
::
Tensor
invvar
=
at
::
empty_like
(
mean
);
at
::
Tensor
invvar
=
at
::
empty_like
(
mean
);
cuda_layer_norm
(
&
output
,
&
mean
,
&
invvar
,
&
input
,
n1
,
n2
,
cuda_layer_norm
(
&
output
,
&
mean
,
&
invvar
,
&
input
,
n1
,
n2
,
normalized_shape
,
&
gamma
,
&
beta
,
epsilon
);
normalized_shape
,
&
gamma
,
&
beta
,
epsilon
);
...
...
csrc/layer_norm_cuda_kernel.cu
View file @
d900e93c
...
@@ -685,18 +685,18 @@ void cuda_layer_norm(
...
@@ -685,18 +685,18 @@ void cuda_layer_norm(
double
epsilon
)
double
epsilon
)
{
{
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_
FLOATING_TYPES
_AND_HALF
(
input
->
type
(),
"layer_norm_cuda_kernel"
,
([
&
]
{
DISPATCH_
DOUBLE_FLOAT
_AND_HALF
(
input
->
scalar_
type
(),
0
,
"layer_norm_cuda_kernel"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
HostApplyLayerNorm
(
HostApplyLayerNorm
(
output
->
data
<
scalar_t
>
(),
output
->
data
<
scalar_t
_0
>
(),
mean
->
data
<
accscalar_t
>
(),
mean
->
data
<
accscalar_t
>
(),
invvar
->
data
<
accscalar_t
>
(),
invvar
->
data
<
accscalar_t
>
(),
input
->
data
<
scalar_t
>
(),
input
->
data
<
scalar_t
_0
>
(),
n1
,
n2
,
n1
,
n2
,
epsilon
,
epsilon
,
gamma
!=
NULL
?
gamma
->
data
<
scalar_t
>
()
:
NULL
,
gamma
!=
NULL
?
gamma
->
data
<
scalar_t
_0
>
()
:
NULL
,
beta
!=
NULL
?
beta
->
data
<
scalar_t
>
()
:
NULL
);
beta
!=
NULL
?
beta
->
data
<
scalar_t
_0
>
()
:
NULL
);
}));
)
}
}
template
<
typename
T
,
typename
U
>
template
<
typename
T
,
typename
U
>
...
@@ -725,7 +725,7 @@ void HostLayerNormGradient(
...
@@ -725,7 +725,7 @@ void HostLayerNormGradient(
const
int
nshared2_a
=
2
*
sizeof
(
U
)
*
threads2
.
y
*
threads2
.
y
*
(
threads2
.
x
+
1
);
const
int
nshared2_a
=
2
*
sizeof
(
U
)
*
threads2
.
y
*
threads2
.
y
*
(
threads2
.
x
+
1
);
const
int
nshared2_b
=
threads2
.
x
*
threads2
.
y
*
sizeof
(
U
);
const
int
nshared2_b
=
threads2
.
x
*
threads2
.
y
*
sizeof
(
U
);
const
int
nshared2
=
nshared2_a
>
nshared2_b
?
nshared2_a
:
nshared2_b
;
const
int
nshared2
=
nshared2_a
>
nshared2_b
?
nshared2_a
:
nshared2_b
;
at
::
Tensor
part_grad_gamma
=
at
::
empty
({
part_size
,
n2
},
input
->
options
().
dtype
(
input
->
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
input
->
type
().
scalar
T
ype
()));
at
::
Tensor
part_grad_gamma
=
at
::
empty
({
part_size
,
n2
},
input
->
options
().
dtype
(
input
->
scalar
_t
ype
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
input
->
scalar
_t
ype
()));
at
::
Tensor
part_grad_beta
=
at
::
empty_like
(
part_grad_gamma
);
at
::
Tensor
part_grad_beta
=
at
::
empty_like
(
part_grad_gamma
);
cuComputePartGradGammaBeta
<<<
blocks2
,
threads2
,
nshared2
,
stream
>>>
(
cuComputePartGradGammaBeta
<<<
blocks2
,
threads2
,
nshared2
,
stream
>>>
(
dout
,
dout
,
...
@@ -787,19 +787,19 @@ void cuda_layer_norm_gradient(
...
@@ -787,19 +787,19 @@ void cuda_layer_norm_gradient(
at
::
Tensor
*
grad_beta
)
at
::
Tensor
*
grad_beta
)
{
{
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
->
type
(),
"cuComputeGradInput"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
->
scalar_
type
(),
0
,
"cuComputeGradInput"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
HostLayerNormGradient
(
HostLayerNormGradient
(
dout
->
data
<
scalar_t
>
(),
dout
->
data
<
scalar_t
_0
>
(),
mean
->
data
<
accscalar_t
>
(),
mean
->
data
<
accscalar_t
>
(),
invvar
->
data
<
accscalar_t
>
(),
invvar
->
data
<
accscalar_t
>
(),
input
,
input
,
n1
,
n2
,
n1
,
n2
,
gamma
->
data
<
scalar_t
>
(),
gamma
->
data
<
scalar_t
_0
>
(),
beta
->
data
<
scalar_t
>
(),
beta
->
data
<
scalar_t
_0
>
(),
epsilon
,
epsilon
,
grad_input
->
data
<
scalar_t
>
(),
grad_input
->
data
<
scalar_t
_0
>
(),
grad_gamma
->
data
<
scalar_t
>
(),
grad_gamma
->
data
<
scalar_t
_0
>
(),
grad_beta
->
data
<
scalar_t
>
());
grad_beta
->
data
<
scalar_t
_0
>
());
}));
)
}
}
csrc/multi_tensor_l2norm_kernel.cu
View file @
d900e93c
...
@@ -75,7 +75,7 @@ at::Tensor multi_tensor_l2norm_cuda(
...
@@ -75,7 +75,7 @@ at::Tensor multi_tensor_l2norm_cuda(
at
::
Tensor
noop_flag
,
at
::
Tensor
noop_flag
,
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
)
std
::
vector
<
std
::
vector
<
at
::
Tensor
>>
tensor_lists
)
{
{
auto
output
=
at
::
zeros
({
320
},
tensor_lists
[
0
][
0
].
options
().
dtype
(
at
::
ScalarType
::
Float
));
auto
output
=
at
::
zeros
({
320
},
tensor_lists
[
0
][
0
].
options
().
dtype
(
at
::
k
Float
));
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
0
][
0
].
scalar_type
(),
0
,
"multi_tensor_l2norm_cuda"
,
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
0
][
0
].
scalar_type
(),
0
,
"multi_tensor_l2norm_cuda"
,
multi_tensor_apply
<
1
>
(
multi_tensor_apply
<
1
>
(
...
...
csrc/multi_tensor_scale_kernel.cu
View file @
d900e93c
...
@@ -86,39 +86,15 @@ void multi_tensor_scale_cuda(
...
@@ -86,39 +86,15 @@ void multi_tensor_scale_cuda(
// If build times suffer, think about where to put this dispatch,
// If build times suffer, think about where to put this dispatch,
// and what logic should be moved out of multi_tensor_apply.
// and what logic should be moved out of multi_tensor_apply.
AT_DISPATCH_FLOATING_TYPES_AND_HALF
(
tensor_lists
[
0
][
0
].
type
(),
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
0
][
0
].
scalar_type
(),
0
,
"multi_tensor_scale_cuda"
,
"multi_tensor_scale_cuda"
,
DISPATCH_FLOAT_AND_HALF
(
tensor_lists
[
1
][
0
].
scalar_type
(),
1
,
"multi_tensor_scale_cuda"
,
[
&
]
multi_tensor_apply
<
2
>
(
{
BLOCK_SIZE
,
// using accscalar_t = acc_type<scalar_t, true>;
chunk_size
,
switch
(
tensor_lists
[
1
][
0
].
scalar_type
())
noop_flag
,
{
tensor_lists
,
case
at
::
ScalarType
::
Half
:
ScaleFunctor
<
scalar_t_0
,
scalar_t_1
>
(),
multi_tensor_apply
<
2
>
(
scale
);
))
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
ScaleFunctor
<
scalar_t
,
at
::
Half
>
(),
scale
);
break
;
case
at
::
ScalarType
::
Float
:
multi_tensor_apply
<
2
>
(
BLOCK_SIZE
,
chunk_size
,
noop_flag
,
tensor_lists
,
ScaleFunctor
<
scalar_t
,
float
>
(),
scale
);
break
;
default:
std
::
stringstream
ss
;
ss
<<
"multi_tensor_scale_cuda not implemented for output type = "
<<
tensor_lists
[
1
][
0
].
dtype
();
AT_ERROR
(
ss
.
str
().
c_str
());
}
});
AT_CUDA_CHECK
(
cudaGetLastError
());
AT_CUDA_CHECK
(
cudaGetLastError
());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
// AT_CUDA_CHECK(cudaDeviceSynchronize());
...
...
csrc/type_shim.h
View file @
d900e93c
...
@@ -3,15 +3,15 @@
...
@@ -3,15 +3,15 @@
// Forward/backward compatiblity hack around
// Forward/backward compatiblity hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
// pending more future-proof guidance from upstream.
// pending more future-proof guidance from upstream.
struct
TypeShim
//
struct TypeShim
{
//
{
const
at
::
Type
&
payload
;
//
const at::Type& payload;
TypeShim
(
const
at
::
Type
&
type
)
:
payload
(
type
)
{}
//
TypeShim(const at::Type& type) : payload(type) {}
// Enable trivial conversion to a const at::Type& for pre-3aeb78
//
// Enable trivial conversion to a const at::Type& for pre-3aeb78
operator
const
at
::
Type
&
(){
return
payload
;
};
//
operator const at::Type&(){ return payload; };
// Enable dispatch switch statements to take *this directly for post-3aeb78
//
// Enable dispatch switch statements to take *this directly for post-3aeb78
operator
at
::
ScalarType
(){
return
payload
.
scalarType
()
;
};
// //
operator at::ScalarType(){ return payload.; };
};
//
};
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
switch(TYPE) \
...
@@ -33,6 +33,52 @@ struct TypeShim
...
@@ -33,6 +33,52 @@ struct TypeShim
}
}
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Half: \
{ \
using scalar_t_##LEVEL = at::Half; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...) \
switch(TYPE) \
{ \
case at::ScalarType::Double: \
{ \
using scalar_t_##LEVEL = double; \
__VA_ARGS__; \
break; \
} \
case at::ScalarType::Float: \
{ \
using scalar_t_##LEVEL = float; \
__VA_ARGS__; \
break; \
} \
default: \
AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
}
template
<
typename
T
>
template
<
typename
T
>
__device__
__forceinline__
T
reduce_block_into_lanes
__device__
__forceinline__
T
reduce_block_into_lanes
(
T
*
x
,
(
T
*
x
,
...
...
csrc/welford.cu
View file @
d900e93c
...
@@ -182,14 +182,14 @@ __host__ int get_tensor_spatial_size(const at::Tensor& input)
...
@@ -182,14 +182,14 @@ __host__ int get_tensor_spatial_size(const at::Tensor& input)
// promote accumulation scalar type. promote half to float.
// promote accumulation scalar type. promote half to float.
__host__
at
::
ScalarType
promote_scalartype
(
const
at
::
Tensor
&
input
)
__host__
at
::
ScalarType
promote_scalartype
(
const
at
::
Tensor
&
input
)
{
{
return
input
.
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Half
?
return
input
.
scalar
_t
ype
()
==
at
::
ScalarType
::
Half
?
at
::
ScalarType
::
Float
:
input
.
type
().
scalar
T
ype
();
at
::
ScalarType
::
Float
:
input
.
scalar
_t
ype
();
}
}
// return single element size, optional accumulation type promotion.
// return single element size, optional accumulation type promotion.
__host__
size_t
get_element_data_size
(
const
at
::
Tensor
&
input
,
bool
accumulation
=
false
)
__host__
size_t
get_element_data_size
(
const
at
::
Tensor
&
input
,
bool
accumulation
=
false
)
{
{
auto
scalar_type
=
accumulation
?
promote_scalartype
(
input
)
:
input
.
type
().
scalar
T
ype
();
auto
scalar_type
=
accumulation
?
promote_scalartype
(
input
)
:
input
.
scalar
_t
ype
();
return
at
::
elementSize
(
scalar_type
);
return
at
::
elementSize
(
scalar_type
);
}
}
...
@@ -846,16 +846,16 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
...
@@ -846,16 +846,16 @@ std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input) {
{
{
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"welford_mean_var_kernel"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"welford_mean_var_kernel"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
welford_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
welford_kernel
<
scalar_t
_0
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
out_mean
.
data
<
accscalar_t
>
(),
out_mean
.
data
<
accscalar_t
>
(),
out_var_biased
.
data
<
accscalar_t
>
(),
out_var_biased
.
data
<
accscalar_t
>
(),
batch_size
,
batch_size
,
feature_size
,
feature_size
,
space_size
);
space_size
);
})
);
);
}
}
return
{
out_mean
,
out_var_biased
};
return
{
out_mean
,
out_var_biased
};
...
@@ -881,40 +881,40 @@ at::Tensor batchnorm_forward_CUDA(
...
@@ -881,40 +881,40 @@ at::Tensor batchnorm_forward_CUDA(
const
dim3
grid
(
feature_size
,
batch_group_size
,
grid_z
);
const
dim3
grid
(
feature_size
,
batch_group_size
,
grid_z
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
input
.
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Half
if
(
input
.
scalar
_t
ype
()
==
at
::
ScalarType
::
Half
&&
weight
.
has_value
()
&&
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Float
)
{
weight
.
value
().
scalar
_t
ype
()
==
at
::
ScalarType
::
Float
)
{
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"batchnorm_forward"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"batchnorm_forward"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
batchnorm_forward_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
batchnorm_forward_kernel
<
scalar_t
_0
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
mean
.
data
<
accscalar_t
>
(),
mean
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
weight
.
has_value
()
?
weight
.
value
().
data
<
accscalar_t
>
()
:
NULL
,
weight
.
has_value
()
?
weight
.
value
().
data
<
accscalar_t
>
()
:
NULL
,
shift
.
has_value
()
?
shift
.
value
().
data
<
accscalar_t
>
()
:
NULL
,
shift
.
has_value
()
?
shift
.
value
().
data
<
accscalar_t
>
()
:
NULL
,
out
.
data
<
scalar_t
>
(),
out
.
data
<
scalar_t
_0
>
(),
space_size
,
space_size
,
batch_size
);
batch_size
);
})
);
);
}
else
{
}
else
{
if
(
weight
.
has_value
())
{
if
(
weight
.
has_value
())
{
AT_CHECK
(
input
.
type
().
scalar
T
ype
()
==
weight
.
value
().
type
().
scalar
T
ype
(),
AT_CHECK
(
input
.
scalar
_t
ype
()
==
weight
.
value
().
scalar
_t
ype
(),
"input.
type().
scalar
T
ype() is not supported with weight.
type().
scalar
T
ype()"
);
"input.scalar
_t
ype() is not supported with weight.scalar
_t
ype()"
);
}
}
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"batchnorm_forward"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"batchnorm_forward"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
batchnorm_forward_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
batchnorm_forward_kernel
<
scalar_t
_0
,
accscalar_t
,
scalar_t
_0
><<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
mean
.
data
<
accscalar_t
>
(),
mean
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
weight
.
has_value
()
?
weight
.
value
().
data
<
scalar_t
>
()
:
NULL
,
weight
.
has_value
()
?
weight
.
value
().
data
<
scalar_t
_0
>
()
:
NULL
,
shift
.
has_value
()
?
shift
.
value
().
data
<
scalar_t
>
()
:
NULL
,
shift
.
has_value
()
?
shift
.
value
().
data
<
scalar_t
_0
>
()
:
NULL
,
out
.
data
<
scalar_t
>
(),
out
.
data
<
scalar_t
_0
>
(),
space_size
,
space_size
,
batch_size
);
batch_size
);
})
);
);
}
}
return
out
;
return
out
;
}
}
...
@@ -952,15 +952,15 @@ std::vector<at::Tensor> reduce_bn_CUDA(
...
@@ -952,15 +952,15 @@ std::vector<at::Tensor> reduce_bn_CUDA(
const
dim3
grid
(
feature_size
);
const
dim3
grid
(
feature_size
);
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
input
.
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Half
if
(
input
.
scalar
_t
ype
()
==
at
::
ScalarType
::
Half
&&
weight
.
has_value
()
&&
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Float
)
{
weight
.
value
().
scalar
_t
ype
()
==
at
::
ScalarType
::
Float
)
{
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"batchnorm_backward_reduce"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"batchnorm_backward_reduce"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
reduce_bn_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
reduce_bn_kernel
<
scalar_t
_0
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
grad_output
.
data
<
scalar_t
>
(),
grad_output
.
data
<
scalar_t
_0
>
(),
mean
.
data
<
accscalar_t
>
(),
mean
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy
.
data
<
accscalar_t
>
(),
...
@@ -970,28 +970,28 @@ std::vector<at::Tensor> reduce_bn_CUDA(
...
@@ -970,28 +970,28 @@ std::vector<at::Tensor> reduce_bn_CUDA(
batch_size
,
batch_size
,
feature_size
,
feature_size
,
space_size
);
space_size
);
})
);
);
}
else
{
}
else
{
if
(
weight
.
has_value
())
{
if
(
weight
.
has_value
())
{
AT_CHECK
(
input
.
type
().
scalar
T
ype
()
==
weight
.
value
().
type
().
scalar
T
ype
(),
AT_CHECK
(
input
.
scalar
_t
ype
()
==
weight
.
value
().
scalar
_t
ype
(),
"input.
type().
scalar
T
ype() is not supported with weight.
type().
scalar
T
ype()"
);
"input.scalar
_t
ype() is not supported with weight.scalar
_t
ype()"
);
}
}
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"batchnorm_backward_reduce"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"batchnorm_backward_reduce"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
reduce_bn_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
reduce_bn_kernel
<
scalar_t
_0
,
accscalar_t
,
scalar_t
_0
><<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
grad_output
.
data
<
scalar_t
>
(),
grad_output
.
data
<
scalar_t
_0
>
(),
mean
.
data
<
accscalar_t
>
(),
mean
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy_xmu
.
data
<
accscalar_t
>
(),
mean_dy_xmu
.
data
<
accscalar_t
>
(),
weight
.
has_value
()
?
grad_weight
.
data
<
scalar_t
>
()
:
NULL
,
weight
.
has_value
()
?
grad_weight
.
data
<
scalar_t
_0
>
()
:
NULL
,
weight
.
has_value
()
?
grad_bias
.
data
<
scalar_t
>
()
:
NULL
,
weight
.
has_value
()
?
grad_bias
.
data
<
scalar_t
_0
>
()
:
NULL
,
batch_size
,
batch_size
,
feature_size
,
feature_size
,
space_size
);
space_size
);
})
);
);
}
}
return
{
mean_dy
,
mean_dy_xmu
,
grad_weight
,
grad_bias
};
return
{
mean_dy
,
mean_dy_xmu
,
grad_weight
,
grad_bias
};
...
@@ -1021,44 +1021,44 @@ at::Tensor batchnorm_backward_CUDA(
...
@@ -1021,44 +1021,44 @@ at::Tensor batchnorm_backward_CUDA(
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
input
.
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Half
if
(
input
.
scalar
_t
ype
()
==
at
::
ScalarType
::
Half
&&
weight
.
has_value
()
&&
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Float
)
{
weight
.
value
().
scalar
_t
ype
()
==
at
::
ScalarType
::
Float
)
{
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"batchnorm_backward"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"batchnorm_backward"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
batchnorm_backward_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
batchnorm_backward_kernel
<
scalar_t
_0
,
accscalar_t
,
accscalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
grad_output
.
data
<
scalar_t
>
(),
grad_output
.
data
<
scalar_t
_0
>
(),
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
mean
.
data
<
accscalar_t
>
(),
mean
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
weight
.
has_value
()
?
weight
.
value
().
data
<
accscalar_t
>
()
:
NULL
,
weight
.
has_value
()
?
weight
.
value
().
data
<
accscalar_t
>
()
:
NULL
,
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy_xmu
.
data
<
accscalar_t
>
(),
mean_dy_xmu
.
data
<
accscalar_t
>
(),
grad_input
.
data
<
scalar_t
>
(),
grad_input
.
data
<
scalar_t
_0
>
(),
space_size
,
space_size
,
batch_size
);
batch_size
);
})
);
);
}
else
{
}
else
{
if
(
weight
.
has_value
())
{
if
(
weight
.
has_value
())
{
AT_CHECK
(
input
.
type
().
scalar
T
ype
()
==
weight
.
value
().
type
().
scalar
T
ype
(),
AT_CHECK
(
input
.
scalar
_t
ype
()
==
weight
.
value
().
scalar
_t
ype
(),
"input.
type().
scalar
T
ype() is not supported with weight.
type().
scalar
T
ype()"
);
"input.scalar
_t
ype() is not supported with weight.scalar
_t
ype()"
);
}
}
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"batchnorm_backward"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"batchnorm_backward"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
batchnorm_backward_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
batchnorm_backward_kernel
<
scalar_t
_0
,
accscalar_t
,
scalar_t
_0
><<<
grid
,
block
,
0
,
stream
>>>
(
grad_output
.
data
<
scalar_t
>
(),
grad_output
.
data
<
scalar_t
_0
>
(),
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
mean
.
data
<
accscalar_t
>
(),
mean
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
weight
.
has_value
()
?
weight
.
value
().
data
<
scalar_t
>
()
:
NULL
,
weight
.
has_value
()
?
weight
.
value
().
data
<
scalar_t
_0
>
()
:
NULL
,
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy_xmu
.
data
<
accscalar_t
>
(),
mean_dy_xmu
.
data
<
accscalar_t
>
(),
grad_input
.
data
<
scalar_t
>
(),
grad_input
.
data
<
scalar_t
_0
>
(),
space_size
,
space_size
,
batch_size
);
batch_size
);
})
);
);
}
}
return
grad_input
;
return
grad_input
;
...
@@ -1083,18 +1083,18 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
...
@@ -1083,18 +1083,18 @@ std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_node
{
{
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
mean_feature_nodes
.
type
(),
"welford_parallel_kernel"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
mean_feature_nodes
.
scalar_
type
(),
0
,
"welford_parallel_kernel"
,
welford_kernel_parallel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
welford_kernel_parallel
<
scalar_t
_0
><<<
grid
,
block
,
0
,
stream
>>>
(
mean_feature_nodes
.
data
<
scalar_t
>
(),
mean_feature_nodes
.
data
<
scalar_t
_0
>
(),
var_biased
.
data
<
scalar_t
>
(),
var_biased
.
data
<
scalar_t
_0
>
(),
out_mean
.
data
<
scalar_t
>
(),
out_mean
.
data
<
scalar_t
_0
>
(),
out_var
.
data
<
scalar_t
>
(),
out_var
.
data
<
scalar_t
_0
>
(),
inv_std
.
data
<
scalar_t
>
(),
inv_std
.
data
<
scalar_t
_0
>
(),
world_size
,
world_size
,
feature_size
,
feature_size
,
eps
,
eps
,
numel
);
numel
);
})
);
);
}
}
return
{
out_mean
,
out_var
,
inv_std
};
return
{
out_mean
,
out_var
,
inv_std
};
...
@@ -1118,27 +1118,27 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
...
@@ -1118,27 +1118,27 @@ std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input) {
at
::
Tensor
semaphores
;
at
::
Tensor
semaphores
;
if
(
grid
.
y
>
1
)
{
if
(
grid
.
y
>
1
)
{
staging_data
=
at
::
empty
({
4
*
stride
*
grid
.
y
},
option
);
staging_data
=
at
::
empty
({
4
*
stride
*
grid
.
y
},
option
);
semaphores
=
at
::
zeros
({
grid
.
x
},
input
.
options
().
dtype
(
at
::
ScalarType
::
Int
));
semaphores
=
at
::
zeros
({
grid
.
x
},
input
.
options
().
dtype
(
at
::
k
Int
));
}
}
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
{
{
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"welford_mean_var_c_last"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"welford_mean_var_c_last"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
accscalar_t
*
staging_data_ptr
=
grid
.
y
>
1
?
staging_data
.
data
<
accscalar_t
>
()
:
nullptr
;
accscalar_t
*
staging_data_ptr
=
grid
.
y
>
1
?
staging_data
.
data
<
accscalar_t
>
()
:
nullptr
;
int
*
semaphores_ptr
=
grid
.
y
>
1
?
semaphores
.
data
<
int
>
()
:
nullptr
;
int
*
semaphores_ptr
=
grid
.
y
>
1
?
semaphores
.
data
<
int
>
()
:
nullptr
;
welford_kernel_c_last
<
scalar_t
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
welford_kernel_c_last
<
scalar_t
_0
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
out_mean
.
data
<
accscalar_t
>
(),
out_mean
.
data
<
accscalar_t
>
(),
out_var_biased
.
data
<
accscalar_t
>
(),
out_var_biased
.
data
<
accscalar_t
>
(),
staging_data_ptr
,
staging_data_ptr
,
semaphores_ptr
,
semaphores_ptr
,
reduction_size
,
reduction_size
,
stride
);
stride
);
})
);
);
}
}
return
{
out_mean
,
out_var_biased
};
return
{
out_mean
,
out_var_biased
};
...
@@ -1161,41 +1161,41 @@ at::Tensor batchnorm_forward_c_last_CUDA(
...
@@ -1161,41 +1161,41 @@ at::Tensor batchnorm_forward_c_last_CUDA(
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
input
.
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Half
if
(
input
.
scalar
_t
ype
()
==
at
::
ScalarType
::
Half
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Float
)
{
&&
weight
.
has_value
()
&&
weight
.
value
().
scalar
_t
ype
()
==
at
::
ScalarType
::
Float
)
{
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"batchnorm_forward"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"batchnorm_forward"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
batchnorm_forward_c_last_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
batchnorm_forward_c_last_kernel
<
scalar_t
_0
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
mean
.
data
<
accscalar_t
>
(),
mean
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
weight
.
has_value
()
?
weight
.
value
().
data
<
accscalar_t
>
()
:
NULL
,
weight
.
has_value
()
?
weight
.
value
().
data
<
accscalar_t
>
()
:
NULL
,
shift
.
has_value
()
?
shift
.
value
().
data
<
accscalar_t
>
()
:
NULL
,
shift
.
has_value
()
?
shift
.
value
().
data
<
accscalar_t
>
()
:
NULL
,
out
.
data
<
scalar_t
>
(),
out
.
data
<
scalar_t
_0
>
(),
reduction_size
,
reduction_size
,
stride
);
stride
);
})
);
);
}
else
{
}
else
{
if
(
weight
.
has_value
())
{
if
(
weight
.
has_value
())
{
AT_CHECK
(
input
.
type
().
scalar
T
ype
()
==
weight
.
value
().
type
().
scalar
T
ype
(),
AT_CHECK
(
input
.
scalar
_t
ype
()
==
weight
.
value
().
scalar
_t
ype
(),
"input.
type().
scalar
T
ype() is not supported with weight.
type().
scalar
T
ype()"
);
"input.scalar
_t
ype() is not supported with weight.scalar
_t
ype()"
);
}
}
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"batchnorm_forward"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"batchnorm_forward"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
batchnorm_forward_c_last_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
,
ELEMENTS_PER_ITER
>
batchnorm_forward_c_last_kernel
<
scalar_t
_0
,
accscalar_t
,
scalar_t
_0
,
ELEMENTS_PER_ITER
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
mean
.
data
<
accscalar_t
>
(),
mean
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
weight
.
has_value
()
?
weight
.
value
().
data
<
scalar_t
>
()
:
NULL
,
weight
.
has_value
()
?
weight
.
value
().
data
<
scalar_t
_0
>
()
:
NULL
,
shift
.
has_value
()
?
shift
.
value
().
data
<
scalar_t
>
()
:
NULL
,
shift
.
has_value
()
?
shift
.
value
().
data
<
scalar_t
_0
>
()
:
NULL
,
out
.
data
<
scalar_t
>
(),
out
.
data
<
scalar_t
_0
>
(),
reduction_size
,
reduction_size
,
stride
);
stride
);
})
);
);
}
}
return
out
;
return
out
;
}
}
...
@@ -1231,22 +1231,22 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
...
@@ -1231,22 +1231,22 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
at
::
Tensor
semaphores
;
at
::
Tensor
semaphores
;
if
(
grid
.
y
>
1
)
{
if
(
grid
.
y
>
1
)
{
staging_data
=
at
::
empty
({
2
*
stride
*
grid
.
y
},
mean
.
options
());
staging_data
=
at
::
empty
({
2
*
stride
*
grid
.
y
},
mean
.
options
());
semaphores
=
at
::
zeros
({
grid
.
x
},
input
.
options
().
dtype
(
at
::
ScalarType
::
Int
));
semaphores
=
at
::
zeros
({
grid
.
x
},
input
.
options
().
dtype
(
at
::
k
Int
));
}
}
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
input
.
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Half
if
(
input
.
scalar
_t
ype
()
==
at
::
ScalarType
::
Half
&&
weight
.
has_value
()
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Float
)
{
&&
weight
.
value
().
scalar
_t
ype
()
==
at
::
ScalarType
::
Float
)
{
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"batchnorm_backward_reduce"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"batchnorm_backward_reduce"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
accscalar_t
*
staging_data_ptr
=
grid
.
y
>
1
?
staging_data
.
data
<
accscalar_t
>
()
:
nullptr
;
accscalar_t
*
staging_data_ptr
=
grid
.
y
>
1
?
staging_data
.
data
<
accscalar_t
>
()
:
nullptr
;
int
*
semaphores_ptr
=
grid
.
y
>
1
?
semaphores
.
data
<
int
>
()
:
nullptr
;
int
*
semaphores_ptr
=
grid
.
y
>
1
?
semaphores
.
data
<
int
>
()
:
nullptr
;
reduce_bn_c_last_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
reduce_bn_c_last_kernel
<
scalar_t
_0
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
grad_output
.
data
<
scalar_t
>
(),
grad_output
.
data
<
scalar_t
_0
>
(),
mean
.
data
<
accscalar_t
>
(),
mean
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy
.
data
<
accscalar_t
>
(),
...
@@ -1257,32 +1257,32 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
...
@@ -1257,32 +1257,32 @@ std::vector<at::Tensor> reduce_bn_c_last_CUDA(
semaphores_ptr
,
semaphores_ptr
,
reduction_size
,
reduction_size
,
stride
);
stride
);
})
);
);
}
else
{
}
else
{
if
(
weight
.
has_value
())
{
if
(
weight
.
has_value
())
{
AT_CHECK
(
input
.
type
().
scalar
T
ype
()
==
weight
.
value
().
type
().
scalar
T
ype
(),
AT_CHECK
(
input
.
scalar
_t
ype
()
==
weight
.
value
().
scalar
_t
ype
(),
"input.
type().
scalar
T
ype() is not supported with weight.
type().
scalar
T
ype()"
);
"input.scalar
_t
ype() is not supported with weight.scalar
_t
ype()"
);
}
}
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"batchnorm_backward_reduce"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"batchnorm_backward_reduce"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
accscalar_t
*
staging_data_ptr
=
grid
.
y
>
1
?
staging_data
.
data
<
accscalar_t
>
()
:
nullptr
;
accscalar_t
*
staging_data_ptr
=
grid
.
y
>
1
?
staging_data
.
data
<
accscalar_t
>
()
:
nullptr
;
int
*
semaphores_ptr
=
grid
.
y
>
1
?
semaphores
.
data
<
int
>
()
:
nullptr
;
int
*
semaphores_ptr
=
grid
.
y
>
1
?
semaphores
.
data
<
int
>
()
:
nullptr
;
reduce_bn_c_last_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
,
ELEMENTS_PER_ITER
>
reduce_bn_c_last_kernel
<
scalar_t
_0
,
accscalar_t
,
scalar_t
_0
,
ELEMENTS_PER_ITER
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
grad_output
.
data
<
scalar_t
>
(),
grad_output
.
data
<
scalar_t
_0
>
(),
mean
.
data
<
accscalar_t
>
(),
mean
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy_xmu
.
data
<
accscalar_t
>
(),
mean_dy_xmu
.
data
<
accscalar_t
>
(),
weight
.
has_value
()
?
grad_weight
.
data
<
scalar_t
>
()
:
NULL
,
weight
.
has_value
()
?
grad_weight
.
data
<
scalar_t
_0
>
()
:
NULL
,
weight
.
has_value
()
?
grad_bias
.
data
<
scalar_t
>
()
:
NULL
,
weight
.
has_value
()
?
grad_bias
.
data
<
scalar_t
_0
>
()
:
NULL
,
staging_data_ptr
,
staging_data_ptr
,
semaphores_ptr
,
semaphores_ptr
,
reduction_size
,
reduction_size
,
stride
);
stride
);
})
);
);
}
}
return
{
mean_dy
,
mean_dy_xmu
,
grad_weight
,
grad_bias
};
return
{
mean_dy
,
mean_dy_xmu
,
grad_weight
,
grad_bias
};
...
@@ -1307,45 +1307,45 @@ at::Tensor batchnorm_backward_c_last_CUDA(
...
@@ -1307,45 +1307,45 @@ at::Tensor batchnorm_backward_c_last_CUDA(
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
input
.
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Half
if
(
input
.
scalar
_t
ype
()
==
at
::
ScalarType
::
Half
&&
weight
.
has_value
()
&&
weight
.
value
().
type
().
scalar
T
ype
()
==
at
::
ScalarType
::
Float
)
{
&&
weight
.
has_value
()
&&
weight
.
value
().
scalar
_t
ype
()
==
at
::
ScalarType
::
Float
)
{
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"batchnorm_forward"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"batchnorm_forward"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
batchnorm_backward_c_last_kernel
<
scalar_t
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
batchnorm_backward_c_last_kernel
<
scalar_t
_0
,
accscalar_t
,
accscalar_t
,
ELEMENTS_PER_ITER
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
grad_output
.
data
<
scalar_t
>
(),
grad_output
.
data
<
scalar_t
_0
>
(),
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
mean
.
data
<
accscalar_t
>
(),
mean
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
weight
.
has_value
()
?
weight
.
value
().
data
<
accscalar_t
>
()
:
NULL
,
weight
.
has_value
()
?
weight
.
value
().
data
<
accscalar_t
>
()
:
NULL
,
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy_xmu
.
data
<
accscalar_t
>
(),
mean_dy_xmu
.
data
<
accscalar_t
>
(),
grad_input
.
data
<
scalar_t
>
(),
grad_input
.
data
<
scalar_t
_0
>
(),
reduction_size
,
reduction_size
,
stride
);
stride
);
})
);
);
}
else
{
}
else
{
if
(
weight
.
has_value
())
{
if
(
weight
.
has_value
())
{
AT_CHECK
(
input
.
type
().
scalar
T
ype
()
==
weight
.
value
().
type
().
scalar
T
ype
(),
AT_CHECK
(
input
.
scalar
_t
ype
()
==
weight
.
value
().
scalar
_t
ype
(),
"input.
type().
scalar
T
ype() is not supported with weight.
type().
scalar
T
ype()"
);
"input.scalar
_t
ype() is not supported with weight.scalar
_t
ype()"
);
}
}
using
namespace
at
;
using
namespace
at
;
AT_
DISPATCH_FLOAT
ING_TYPES
_AND_HALF
(
input
.
type
(),
"batchnorm_forward"
,
([
&
]
{
DISPATCH_FLOAT_AND_HALF
(
input
.
scalar_
type
(),
0
,
"batchnorm_forward"
,
using
accscalar_t
=
at
::
acc_type
<
scalar_t
,
true
>
;
using
accscalar_t
=
at
::
acc_type
<
scalar_t
_0
,
true
>
;
batchnorm_backward_c_last_kernel
<
scalar_t
,
accscalar_t
,
scalar_t
,
ELEMENTS_PER_ITER
>
batchnorm_backward_c_last_kernel
<
scalar_t
_0
,
accscalar_t
,
scalar_t
_0
,
ELEMENTS_PER_ITER
>
<<<
grid
,
block
,
0
,
stream
>>>
(
<<<
grid
,
block
,
0
,
stream
>>>
(
grad_output
.
data
<
scalar_t
>
(),
grad_output
.
data
<
scalar_t
_0
>
(),
input
.
data
<
scalar_t
>
(),
input
.
data
<
scalar_t
_0
>
(),
mean
.
data
<
accscalar_t
>
(),
mean
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
inv_std
.
data
<
accscalar_t
>
(),
weight
.
has_value
()
?
weight
.
value
().
data
<
scalar_t
>
()
:
NULL
,
weight
.
has_value
()
?
weight
.
value
().
data
<
scalar_t
_0
>
()
:
NULL
,
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy
.
data
<
accscalar_t
>
(),
mean_dy_xmu
.
data
<
accscalar_t
>
(),
mean_dy_xmu
.
data
<
accscalar_t
>
(),
grad_input
.
data
<
scalar_t
>
(),
grad_input
.
data
<
scalar_t
_0
>
(),
reduction_size
,
reduction_size
,
stride
);
stride
);
})
);
);
}
}
return
grad_input
;
return
grad_input
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment