ColossalAI commit 1762ba14
Authored Mar 31, 2022 by Ziyue Jiang; committed by binmakeswell on Apr 06, 2022
Parent: ce8a3eae

fix format (#563)
Showing 1 changed file with 54 additions and 65 deletions (+54 -65)

colossalai/kernel/cuda_native/csrc/multi_tensor_scale_kernel.cu
The commit is a formatting-only cleanup: the include order and line wrapping change, and the kernel logic does not. Each hunk below is therefore shown as the file reads after the commit.

@@ -9,32 +9,29 @@

// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
#include <sstream>

#include "multi_tensor_apply.cuh"
#include "type_shim.h"

#define BLOCK_SIZE 512
#define ILP 4

template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}

template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
                                           int src_offset) {
  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type
      LT;
  ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}
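An aside on the two helpers above: is_aligned checks that a pointer sits on an ILP * sizeof(T) boundary, and load_store then moves ILP contiguous elements with a single wide assignment by viewing the memory through a std::aligned_storage type. Below is a minimal, hypothetical sketch of the same idiom in isolation; the vectorized_copy kernel, its launch shape, and the assumption that n is a multiple of ILP are illustrative and not part of this file.

#include <cstdint>
#include <type_traits>

#define ILP 4

template <typename T>
__device__ __forceinline__ bool is_aligned(T *p) {
  return ((uint64_t)p) % (ILP * sizeof(T)) == 0;
}

template <typename T>
__device__ __forceinline__ void load_store(T *dst, T *src, int dst_offset,
                                           int src_offset) {
  // LT has the size and alignment of ILP consecutive T's, so one assignment
  // compiles to a single wide load + store (128 bits for float with ILP = 4).
  typedef typename std::aligned_storage<ILP * sizeof(T), ILP * alignof(T)>::type LT;
  ((LT *)dst)[dst_offset] = ((LT *)src)[src_offset];
}

// Toy grid-stride copy: each iteration moves ILP elements at once.
// Assumes n % ILP == 0 and that both pointers satisfy is_aligned.
template <typename T>
__global__ void vectorized_copy(T *dst, T *src, int n) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i * ILP < n;
       i += gridDim.x * blockDim.x) {
    load_store(dst, src, i, i);
  }
}

ScaleFunctor below takes this vectorized path only when n and chunk_size are multiples of ILP and both pointers pass is_aligned; otherwise it falls back to a scalar loop. The hunk continues: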
template <typename in_t, typename out_t>
struct ScaleFunctor {
  __device__ __forceinline__ void operator()(int chunk_size,
                                             volatile int *noop_gmem,
                                             TensorListMetadata<2> &tl,
                                             float scale) {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;
@@ -43,93 +40,85 @@ struct ScaleFunctor

    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    in_t *in = (in_t *)tl.addresses[0][tensor_loc];
    in += chunk_idx * chunk_size;

    out_t *out = (out_t *)tl.addresses[1][tensor_loc];
    out += chunk_idx * chunk_size;

    n -= chunk_idx * chunk_size;

    bool finite = true;
    in_t r_in[ILP];
    out_t r_out[ILP];

    // to make things simple, we put aligned case in a different code path
    if (n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(in) &&
        is_aligned(out)) {
      for (int i_start = threadIdx.x;
           i_start * ILP < n && i_start * ILP < chunk_size;
           i_start += blockDim.x) {
        // load
        load_store(r_in, in, 0, i_start);
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          r_out[ii] = static_cast<float>(r_in[ii]) * scale;
          finite = finite && isfinite(r_in[ii]);
        }
        // store
        load_store(out, r_out, i_start, 0);
      }
    } else {
      // Non-divergent exit condition for __syncthreads, not necessary here
      for (int i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * ILP) {
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          r_in[ii] = 0;
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) r_in[ii] = in[i];
        }
        // note for clarification to future michael:
        // From a pure memory dependency perspective, there's likely no point
        // unrolling the write loop, since writes just fire off once their LDGs
        // arrive. Put another way, the STGs are dependent on the LDGs, but not
        // on each other. There is still compute ILP benefit from unrolling the
        // loop though.
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          r_out[ii] = static_cast<float>(r_in[ii]) * scale;
          finite = finite && isfinite(r_in[ii]);
        }
#pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) out[i] = r_out[ii];
        }
      }
    }
    if (!finite)
      *noop_gmem =
          1;  // Blindly fire off a write. These will race but that's ok.
  }
};
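For intuition about the chunk bookkeeping above (chunk_idx, and n -= chunk_idx * chunk_size): multi_tensor_apply launches one CUDA block per (tensor, chunk) pair and records the mapping in TensorListMetadata, so each block scales at most chunk_size elements of one tensor. The host-side loop below is a hypothetical illustration of that mapping with made-up sizes; the real launch uses a much larger chunk_size.

#include <cstdio>
#include <vector>

int main() {
  const int chunk_size = 4;          // illustrative only
  std::vector<int> sizes = {6, 10};  // two tensors in the list
  int block = 0;
  for (int t = 0; t < (int)sizes.size(); ++t) {
    int chunks = (sizes[t] + chunk_size - 1) / chunk_size;
    for (int c = 0; c < chunks; ++c, ++block) {
      // Inside ScaleFunctor, n starts as sizes[tensor_loc]; after
      // n -= chunk_idx * chunk_size the loops touch min(n, chunk_size)
      // elements starting at in + chunk_idx * chunk_size.
      int remaining = sizes[t] - c * chunk_size;
      int elems = remaining < chunk_size ? remaining : chunk_size;
      printf("block %d -> tensor %d, chunk %d, %d elements\n", block, t, c, elems);
    }
  }
  return 0;
}

The hunk continues with the host entry point that drives the dispatch: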
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
                             float scale) {
  using namespace at;
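  // (Annotation, not part of the original file: DISPATCH_FLOAT_AND_HALF comes
  // from type_shim.h; it switches on the runtime scalar type and instantiates
  // its body with scalar_t_0 / scalar_t_1 bound to float or at::Half, so the
  // nested dispatch below covers every (in_t, out_t) combination.)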
  // The output (downscaled) type is always float.
  // If build times suffer, think about where to put this dispatch,
  // and what logic should be moved out of multi_tensor_apply.
  DISPATCH_FLOAT_AND_HALF(
      tensor_lists[0][0].scalar_type(), 0, "multi_tensor_scale_cuda",
      DISPATCH_FLOAT_AND_HALF(
          tensor_lists[1][0].scalar_type(), 1, "multi_tensor_scale_cuda",
          multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                                ScaleFunctor<scalar_t_0, scalar_t_1>(),
                                scale);))
  AT_CUDA_CHECK(cudaGetLastError());

  // AT_CUDA_CHECK(cudaDeviceSynchronize());
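Finally, a hypothetical host-side sketch of how this entry point could be driven directly from C++ with ATen. The tensor shapes, chunk size, and scale value are illustrative assumptions, not taken from this commit; tensor_lists carries one list of inputs and one list of outputs (matching TensorListMetadata<2>), and noop_flag is an int tensor the kernel sets to 1 if any input element is non-finite.

#include <ATen/ATen.h>
#include <vector>

// Declared here for the sketch; defined in multi_tensor_scale_kernel.cu.
void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag,
                             std::vector<std::vector<at::Tensor>> tensor_lists,
                             float scale);

void scale_example() {
  auto half_opts = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA);
  auto float_opts = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA);
  auto int_opts = at::TensorOptions().dtype(at::kInt).device(at::kCUDA);

  // One input list and one output list of matching shapes.
  std::vector<at::Tensor> ins = {at::ones({1024}, half_opts),
                                 at::ones({4096}, half_opts)};
  std::vector<at::Tensor> outs = {at::empty({1024}, float_opts),
                                  at::empty({4096}, float_opts)};

  // Set to 1 by the kernel if any input element is inf/nan.
  at::Tensor noop_flag = at::zeros({1}, int_opts);

  // Computes out[i] = static_cast<float>(in[i]) * scale for every tensor pair.
  multi_tensor_scale_cuda(2048 * 32, noop_flag, {ins, outs}, 1.0f / 1024.0f);

  if (noop_flag.item<int>() != 0) {
    // At least one input element was non-finite; e.g. skip the optimizer step.
  }
}

In ColossalAI this function is normally reached through the extension's Python bindings rather than called directly; the direct call above just makes the argument layout concrete.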