Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
f8b9aaef
Commit
f8b9aaef
authored
Jul 12, 2022
by
Sze-qq
Committed by
Frank Lee
Jul 13, 2022
Browse files
[NFC] polish colossalai/kernel/cuda_native/csrc/type_shim.h code style (#1260)
parent
f660152c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
210 additions
and
254 deletions
+210
-254
colossalai/kernel/cuda_native/csrc/type_shim.h
colossalai/kernel/cuda_native/csrc/type_shim.h
+210
-254
No files found.
colossalai/kernel/cuda_native/csrc/type_shim.h
View file @
f8b9aaef
#pragma once

#include <ATen/ATen.h>

#include "compat.h"

// Dispatches __VA_ARGS__ with `scalar_t` aliased to the C++ type matching the
// runtime dtype TYPE. Supports Half and BFloat16 only; any other dtype raises
// via AT_ERROR with NAME in the message.
#define DISPATCH_HALF_AND_BFLOAT(TYPE, NAME, ...)                     \
  switch (TYPE) {                                                     \
    case at::ScalarType::Half: {                                      \
      using scalar_t = at::Half;                                      \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::BFloat16: {                                  \
      using scalar_t = at::BFloat16;                                  \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    default:                                                          \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
// Dispatches __VA_ARGS__ with `scalar_t_in` / `scalar_t_out` aliased to the
// C++ types matching the runtime dtypes TYPEIN / TYPEOUT. A Float input may
// pair with a Float, Half, or BFloat16 output; Half and BFloat16 inputs only
// pair with the same output type. Unsupported combinations raise via AT_ERROR.
#define DISPATCH_FLOAT_HALF_AND_BFLOAT_INOUT_TYPES(TYPEIN, TYPEOUT, NAME, ...) \
  switch (TYPEIN) {                                                            \
    case at::ScalarType::Float: {                                              \
      using scalar_t_in = float;                                               \
      switch (TYPEOUT) {                                                       \
        case at::ScalarType::Float: {                                          \
          using scalar_t_out = float;                                          \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::Half: {                                           \
          using scalar_t_out = at::Half;                                       \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        case at::ScalarType::BFloat16: {                                       \
          using scalar_t_out = at::BFloat16;                                   \
          __VA_ARGS__;                                                         \
          break;                                                               \
        }                                                                      \
        default:                                                               \
          AT_ERROR(#NAME, " not implemented for '", toString(TYPEOUT), "'");   \
      }                                                                        \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::Half: {                                               \
      using scalar_t_in = at::Half;                                            \
      using scalar_t_out = at::Half;                                           \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    case at::ScalarType::BFloat16: {                                           \
      using scalar_t_in = at::BFloat16;                                        \
      using scalar_t_out = at::BFloat16;                                       \
      __VA_ARGS__;                                                             \
      break;                                                                   \
    }                                                                          \
    default:                                                                   \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPEIN), "'");        \
  }
// Forward/backward compatibility hack around
// https://github.com/pytorch/pytorch/commit/3aeb78079bcd68282fe9117088e138b77318e288
//   TypeShim(const at::Type& type) : payload(type) {}
//   // Enable trivial conversion to a const at::Type& for pre-3aeb78
//   operator const at::Type&(){ return payload; };
//   // Enable dispatch switch statements to take *this directly for post-3aeb78
//   //operator at::ScalarType(){ return payload.; };
// };
// Dispatches __VA_ARGS__ with `scalar_t_<LEVEL>` aliased to the C++ type
// matching the runtime dtype TYPE (Float or Half). The LEVEL suffix lets
// nested invocations introduce distinct type aliases in the same scope.
// Unsupported dtypes raise via AT_ERROR with NAME in the message.
#define DISPATCH_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)               \
  switch (TYPE) {                                                     \
    case at::ScalarType::Float: {                                     \
      using scalar_t_##LEVEL = float;                                 \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::Half: {                                      \
      using scalar_t_##LEVEL = at::Half;                              \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    default:                                                          \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
// Dispatches __VA_ARGS__ with `scalar_t_<LEVEL>` aliased to the C++ type
// matching the runtime dtype TYPE (Float, Half, or Byte/uint8_t).
// Unsupported dtypes raise via AT_ERROR with NAME in the message.
#define DISPATCH_FLOAT_HALF_AND_BYTE(TYPE, LEVEL, NAME, ...)          \
  switch (TYPE) {                                                     \
    case at::ScalarType::Float: {                                     \
      using scalar_t_##LEVEL = float;                                 \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::Half: {                                      \
      using scalar_t_##LEVEL = at::Half;                              \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::Byte: {                                      \
      using scalar_t_##LEVEL = uint8_t;                               \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    default:                                                          \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
// Dispatches __VA_ARGS__ with `scalar_t_<LEVEL>` aliased to the C++ type
// matching the runtime dtype TYPE (Double, Float, or Half).
// Unsupported dtypes raise via AT_ERROR with NAME in the message.
#define DISPATCH_DOUBLE_FLOAT_AND_HALF(TYPE, LEVEL, NAME, ...)        \
  switch (TYPE) {                                                     \
    case at::ScalarType::Double: {                                    \
      using scalar_t_##LEVEL = double;                                \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::Float: {                                     \
      using scalar_t_##LEVEL = float;                                 \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::Half: {                                      \
      using scalar_t_##LEVEL = at::Half;                              \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    default:                                                          \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
// Dispatches __VA_ARGS__ with `scalar_t_<LEVEL>` aliased to the C++ type
// matching the runtime dtype TYPE (Double or Float).
// Unsupported dtypes raise via AT_ERROR with NAME in the message.
#define DISPATCH_DOUBLE_AND_FLOAT(TYPE, LEVEL, NAME, ...)             \
  switch (TYPE) {                                                     \
    case at::ScalarType::Double: {                                    \
      using scalar_t_##LEVEL = double;                                \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    case at::ScalarType::Float: {                                     \
      using scalar_t_##LEVEL = float;                                 \
      __VA_ARGS__;                                                    \
      break;                                                          \
    }                                                                 \
    default:                                                          \
      AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'"); \
  }
// Dispatches __VA_ARGS__ with `g_scalar_t_<LEVEL>` / `p_scalar_t_<LEVEL>`
// aliased to the C++ types matching the runtime dtypes GTYPE / PTYPE
// (gradient / parameter tensors). All four {Float, Half} x {Float, Half}
// combinations are supported; anything else raises via AT_ERROR.
// NOTE: the error message previously read "NAMEnot implemented" because the
// literal lacked a leading space; fixed to match the other dispatch macros.
#define DISPATCH_FLOAT_AND_HALF_FOR_G_P(GTYPE, PTYPE, LEVEL, NAME, ...)        \
  if (GTYPE == at::ScalarType::Float && PTYPE == at::ScalarType::Float) {      \
    using g_scalar_t_##LEVEL = float;                                          \
    using p_scalar_t_##LEVEL = float;                                          \
    __VA_ARGS__;                                                               \
  } else if (GTYPE == at::ScalarType::Float &&                                 \
             PTYPE == at::ScalarType::Half) {                                  \
    using g_scalar_t_##LEVEL = float;                                          \
    using p_scalar_t_##LEVEL = at::Half;                                       \
    __VA_ARGS__;                                                               \
  } else if (GTYPE == at::ScalarType::Half &&                                  \
             PTYPE == at::ScalarType::Float) {                                 \
    using g_scalar_t_##LEVEL = at::Half;                                       \
    using p_scalar_t_##LEVEL = float;                                          \
    __VA_ARGS__;                                                               \
  } else if (GTYPE == at::ScalarType::Half && PTYPE == at::ScalarType::Half) { \
    using g_scalar_t_##LEVEL = at::Half;                                       \
    using p_scalar_t_##LEVEL = at::Half;                                       \
    __VA_ARGS__;                                                               \
  } else {                                                                     \
    AT_ERROR(#NAME, " not implemented for '", toString(GTYPE),                 \
             toString(PTYPE), "'");                                            \
  }
// Sum-reduces one value per thread across the block, leaving the result in the
// first `lanes` lanes of the first warp (and, if share_result, in x[0..lanes)).
//
// x            scratch buffer with one T slot per thread; written and read with
//              __syncthreads() between, so it is expected to point to shared
//              memory — TODO(review): confirm at call sites.
// val          this thread's contribution to the reduction.
// lanes        number of lanes that receive a final value; intended <= 32 and,
//              presumably, a power of 2 — confirm with callers.
// share_result when true, lanes 0..lanes-1 also store `final` back into x and
//              a __syncthreads() makes it visible to the whole block.
//
// blockDim.x * blockDim.y is intended to be a multiple of 32. The return value
// is only meaningful in threads with tid < lanes.
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes(
    T *x, T val, int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize =
      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

  // Stage each thread's value so the tree reduction below can pair lanes.
  if (blockSize >= 64) {
    x[tid] = val;
    __syncthreads();
  }

  // Pairwise tree reduction in the buffer until <= 64 partials remain.
#pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i) x[tid] = x[tid] + x[tid + i];
    __syncthreads();
  }

  T final;

  // First warp folds the last 64 partials, then finishes with shuffles.
  if (tid < 32) {
    if (blockSize >= 64)
      final = x[tid] + x[tid + 32];
    else
      final = val;
      // __SYNCWARP();

#pragma unroll
    for (int i = 16; i >= lanes; i >>= 1)
      final = final + __shfl_down_sync(0xffffffff, final, i);
  }

  if (share_result) {
    if (tid < lanes) x[tid] = final;  // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}
// Max-of-absolute-values variant of reduce_block_into_lanes: reduces
// fmaxf(fabsf(a), fabsf(b)) over one value per thread across the block,
// leaving the result in the first `lanes` lanes of the first warp (and, if
// share_result, in x[0..lanes)).
//
// x            scratch buffer with one T slot per thread; written and read with
//              __syncthreads() between, so it is expected to point to shared
//              memory — TODO(review): confirm at call sites.
// val          this thread's contribution to the reduction.
// lanes        number of lanes that receive a final value; intended <= 32.
// share_result when true, lanes 0..lanes-1 also store `final` back into x and
//              a __syncthreads() makes it visible to the whole block.
//
// blockDim.x * blockDim.y is intended to be a multiple of 32. The fmaxf/fabsf
// calls operate in float precision regardless of T. The return value is only
// meaningful in threads with tid < lanes.
template <typename T>
__device__ __forceinline__ T reduce_block_into_lanes_max_op(
    T *x, T val, int lanes = 1,
    bool share_result = false)  // lanes is intended to be <= 32.
{
  int tid = threadIdx.x + threadIdx.y * blockDim.x;
  int blockSize =
      blockDim.x * blockDim.y;  // blockSize is intended to be a multiple of 32.

  // Stage each thread's value so the tree reduction below can pair lanes.
  if (blockSize >= 64) {
    x[tid] = val;
    __syncthreads();
  }

  // Pairwise tree reduction (max of abs) until <= 64 partials remain.
#pragma unroll
  for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
    if (tid < i)
      x[tid] = fmaxf(fabsf(x[tid]), fabsf(x[tid + i]));
    __syncthreads();
  }

  T final;

  // First warp folds the last 64 partials, then finishes with shuffles.
  if (tid < 32) {
    if (blockSize >= 64)
      final = fmaxf(fabsf(x[tid]), fabsf(x[tid + 32]));
    else
      final = val;
      // __SYNCWARP();

#pragma unroll
    for (int i = 16; i >= lanes; i >>= 1)
      final =
          fmaxf(fabsf(final), fabsf(__shfl_down_sync(0xffffffff, final, i)));
  }

  if (share_result) {
    if (tid < lanes) x[tid] = final;  // EpilogueOp
    // Make sure the smem result is visible to all warps.
    __syncthreads();
  }

  return final;
}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment