Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b5de0b4e
Commit
b5de0b4e
authored
Feb 04, 2026
by
zhuwenwen
Browse files
support torch2.9
parent
70506d98
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
189 additions
and
724 deletions
+189
-724
CMakeLists.txt
CMakeLists.txt
+0
-1
csrc/core/batch_invariant.hpp
csrc/core/batch_invariant.hpp
+20
-0
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+68
-0
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+90
-271
csrc/ops.h
csrc/ops.h
+0
-6
csrc/opt/layernorm_kernels_opt.cu
csrc/opt/layernorm_kernels_opt.cu
+0
-413
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+0
-13
requirements/rocm.txt
requirements/rocm.txt
+4
-4
setup.py
setup.py
+3
-3
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-11
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+4
-2
No files found.
CMakeLists.txt
View file @
b5de0b4e
...
...
@@ -260,7 +260,6 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/opt/transpose_kernels.cu"
"csrc/opt/activation_kernels_opt.cu"
"csrc/opt/layernorm_kernels_opt.cu"
# "csrc/layernorm_quant_kernels.cu"
"csrc/sampler.cu"
"csrc/cuda_view.cu"
...
...
csrc/core/batch_invariant.hpp
0 → 100644
View file @
b5de0b4e
#pragma once
#include <cstdlib>
#include <string>
#include <cctype>
namespace
vllm
{
// vllm_is_batch_invariant(); returns true
// if env VLLM_BATCH_INVARIANT=1
inline
bool
vllm_is_batch_invariant
()
{
static
bool
cached
=
[]()
{
std
::
string
env_key
=
"VLLM_BATCH_INVARIANT"
;
const
char
*
val
=
std
::
getenv
(
env_key
.
c_str
());
return
(
val
&&
std
::
atoi
(
val
)
!=
0
)
?
1
:
0
;
}();
return
cached
;
}
}
// namespace vllm
csrc/dispatch_utils.h
View file @
b5de0b4e
...
...
@@ -88,3 +88,71 @@
#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_VEC_SIZE(VEC_SIZE, ...) \
switch (VEC_SIZE) { \
case 16: { \
constexpr int vec_size = 16; \
__VA_ARGS__(); \
break; \
} \
case 8: { \
constexpr int vec_size = 8; \
__VA_ARGS__(); \
break; \
} \
case 4: { \
constexpr int vec_size = 4; \
__VA_ARGS__(); \
break; \
} \
case 2: { \
constexpr int vec_size = 2; \
__VA_ARGS__(); \
break; \
} \
default: { \
constexpr int vec_size = 1; \
__VA_ARGS__(); \
break; \
} \
}
#define VLLM_DISPATCH_BOOL(expr, const_expr, ...) \
if (expr) { \
constexpr bool const_expr = true; \
__VA_ARGS__(); \
} else { \
constexpr bool const_expr = false; \
__VA_ARGS__(); \
}
#define VLLM_DISPATCH_GROUP_SIZE(group_size, const_group_size, ...) \
if (group_size == 128) { \
constexpr int const_group_size = 128; \
__VA_ARGS__(); \
} else if (group_size == 64) { \
constexpr int const_group_size = 64; \
__VA_ARGS__(); \
}
#define VLLM_DISPATCH_RANK234(NUM_DIMS, ...) \
switch (NUM_DIMS) { \
case 2: { \
constexpr int tensor_rank = 2; \
__VA_ARGS__(); \
break; \
} \
case 3: { \
constexpr int tensor_rank = 3; \
__VA_ARGS__(); \
break; \
} \
case 4: { \
constexpr int tensor_rank = 4; \
__VA_ARGS__(); \
break; \
} \
default: \
TORCH_CHECK(false, "Expects rank 2, 3 or 4 tensors but got ", NUM_DIMS); \
}
\ No newline at end of file
csrc/layernorm_kernels.cu
View file @
b5de0b4e
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "core/batch_invariant.hpp"
#include "quantization/vectorization_utils.cuh"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
...
...
@@ -8,20 +10,52 @@
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
int
VEC_SIZE
,
int
NUM_DIMS
>
__global__
void
rms_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
const
int64_t
input_stride_d2
,
// input.stride(-2)
const
int64_t
input_stride_d3
,
// input.stride(-3)
const
int64_t
input_stride_d4
,
// input.stride(-4)
const
int64_t
input_shape_d2
,
// input.size(-2)
const
int64_t
input_shape_d3
,
// input.size(-3)
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
const
scalar_t
*
input_row
;
if
constexpr
(
NUM_DIMS
==
2
)
{
// 2D for layernorm normal case [batch_size, hidden]
input_row
=
input
+
blockIdx
.
x
*
input_stride_d2
;
}
else
if
constexpr
(
NUM_DIMS
==
3
)
{
// 3D for q/k norm [batch_size, num_heads, head_size]
int
batch_idx
=
blockIdx
.
x
/
input_shape_d2
;
int
head_idx
=
blockIdx
.
x
%
input_shape_d2
;
input_row
=
input
+
batch_idx
*
input_stride_d3
+
head_idx
*
input_stride_d2
;
}
else
if
constexpr
(
NUM_DIMS
==
4
)
{
// 4D for transformers model_impl qk norm [batch, seq, head, head_dim]
int
batch_idx
=
blockIdx
.
x
/
(
input_shape_d3
*
input_shape_d2
);
int
remaining
=
blockIdx
.
x
%
(
input_shape_d3
*
input_shape_d2
);
int
seq_idx
=
remaining
/
input_shape_d2
;
int
head_idx
=
remaining
%
input_shape_d2
;
input_row
=
input
+
batch_idx
*
input_stride_d4
+
seq_idx
*
input_stride_d3
+
head_idx
*
input_stride_d2
;
}
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
auto
vec_op
=
[
&
variance
](
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>&
vec
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
VEC_SIZE
;
++
i
)
{
float
x
=
static_cast
<
float
>
(
vec
.
val
[
i
]);
variance
+=
x
*
x
;
}
};
auto
scalar_op
=
[
&
variance
](
const
scalar_t
&
val
)
{
float
x
=
static_cast
<
float
>
(
val
);
variance
+=
x
*
x
;
};
vllm
::
vectorize_read_with_alignment
<
VEC_SIZE
>
(
input_row
,
hidden_size
,
threadIdx
.
x
,
blockDim
.
x
,
vec_op
,
scalar_op
);
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
...
...
@@ -32,10 +66,20 @@ __global__ void rms_norm_kernel(
}
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
scalar_t
*
out_row
=
out
+
blockIdx
.
x
*
hidden_size
;
auto
*
v_in
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
input_row
);
auto
*
v_w
=
reinterpret_cast
<
const
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
weight
);
auto
*
v_out
=
reinterpret_cast
<
vec_n_t
<
scalar_t
,
VEC_SIZE
>*>
(
out_row
);
for
(
int
i
=
threadIdx
.
x
;
i
<
hidden_size
/
VEC_SIZE
;
i
+=
blockDim
.
x
)
{
vec_n_t
<
scalar_t
,
VEC_SIZE
>
dst
;
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src1
=
v_in
[
i
];
vec_n_t
<
scalar_t
,
VEC_SIZE
>
src2
=
v_w
[
i
];
#pragma unroll
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
j
++
)
{
float
x
=
static_cast
<
float
>
(
src1
.
val
[
j
]);
dst
.
val
[
j
]
=
((
scalar_t
)(
x
*
s_variance
))
*
src2
.
val
[
j
];
}
v_out
[
i
]
=
dst
;
}
}
...
...
@@ -135,211 +179,6 @@ fused_add_rms_norm_kernel(
}
}
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck.
_f16VecPN struct extends _f16Vec to add operations specifically required for
polynomial normalization (poly norm).
The original _f16Vec does not include the sum-of-powers computation or
in-place polynomial normalization logic. */
template
<
typename
scalar_t
,
int
width
>
struct
alignas
(
16
)
_f16VecPN
:
_f16Vec
<
scalar_t
,
width
>
{
using
Base
=
_f16Vec
<
scalar_t
,
width
>
;
using
Converter
=
typename
Base
::
Converter
;
using
T1
=
typename
Base
::
T1
;
using
T2
=
typename
Base
::
T2
;
using
Base
::
data
;
__device__
auto
sum_pows
()
const
{
float
s2
=
0.0
f
,
s4
=
0.0
f
,
s6
=
0.0
f
;
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
z
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
float
x2
=
z
.
x
*
z
.
x
;
float
x4
=
x2
*
x2
;
float
x6
=
x4
*
x2
;
float
y2
=
z
.
y
*
z
.
y
;
float
y4
=
y2
*
y2
;
float
y6
=
y4
*
y2
;
s2
+=
x2
+
y2
;
s4
+=
x4
+
y4
;
s6
+=
x6
+
y6
;
}
return
std
::
make_tuple
(
s2
,
s4
,
s6
);
}
__device__
void
poly_norm_inplace
(
const
float
w2_inv_std
,
const
float
w1_inv_std2
,
const
float
w0_inv_std3
,
const
float
bias
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
z
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
float
x2
=
z
.
x
*
z
.
x
;
float
x3
=
x2
*
z
.
x
;
z
.
x
=
w2_inv_std
*
z
.
x
+
w1_inv_std2
*
x2
+
w0_inv_std3
*
x3
+
bias
;
float
y2
=
z
.
y
*
z
.
y
;
float
y3
=
y2
*
z
.
y
;
z
.
y
=
w2_inv_std
*
z
.
y
+
w1_inv_std2
*
y2
+
w0_inv_std3
*
y3
+
bias
;
auto
out
=
Converter
::
convert
(
z
);
data
[
i
]
=
out
.
x
;
data
[
i
+
1
]
=
out
.
y
;
}
}
};
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
poly_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [3]
const
scalar_t
*
__restrict__
bias
,
// [1]
const
float
epsilon
,
const
int
hidden_size
)
{
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert
(
std
::
is_pod_v
<
_f16VecPN
<
scalar_t
,
width
>>
);
static_assert
(
sizeof
(
_f16VecPN
<
scalar_t
,
width
>
)
==
sizeof
(
scalar_t
)
*
width
);
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto
*
__restrict__
input_v
=
reinterpret_cast
<
const
_f16VecPN
<
scalar_t
,
width
>*>
(
input
);
const
int
vec_hidden_size
=
hidden_size
/
width
;
float
variance
=
0.0
f
;
float
variance2
=
0.0
f
;
float
variance3
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
_f16VecPN
<
scalar_t
,
width
>
temp
=
input_v
[
id
];
auto
[
x2
,
x4
,
x6
]
=
temp
.
sum_pows
();
variance
+=
x2
;
variance2
+=
x4
;
variance3
+=
x6
;
}
float3
thread_variances
=
make_float3
(
variance
,
variance2
,
variance3
);
struct
SumOp
{
__device__
float3
operator
()(
const
float3
&
a
,
const
float3
&
b
)
const
{
return
make_float3
(
a
.
x
+
b
.
x
,
a
.
y
+
b
.
y
,
a
.
z
+
b
.
z
);
}
};
using
BlockReduce
=
cub
::
BlockReduce
<
float3
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
float3
block_variances
=
BlockReduce
(
reduceStore
).
Reduce
(
thread_variances
,
SumOp
{},
blockDim
.
x
);
variance
=
block_variances
.
x
;
variance2
=
block_variances
.
y
;
variance3
=
block_variances
.
z
;
__shared__
float
s_w2_inv_std
;
__shared__
float
s_w1_inv_std2
;
__shared__
float
s_w0_inv_std3
;
__shared__
float
s_bias
;
if
(
threadIdx
.
x
==
0
)
{
float
w0
=
(
float
)
weight
[
0
];
float
w1
=
(
float
)
weight
[
1
];
float
w2
=
(
float
)
weight
[
2
];
s_bias
=
(
float
)
bias
[
0
];
s_w2_inv_std
=
w2
*
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_w1_inv_std2
=
w1
*
rsqrtf
(
variance2
/
hidden_size
+
epsilon
);
s_w0_inv_std3
=
w0
*
rsqrtf
(
variance3
/
hidden_size
+
epsilon
);
}
__syncthreads
();
auto
*
__restrict__
out_v
=
reinterpret_cast
<
_f16VecPN
<
scalar_t
,
width
>*>
(
out
);
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
_f16VecPN
<
scalar_t
,
width
>
temp
=
input_v
[
id
];
temp
.
poly_norm_inplace
(
s_w2_inv_std
,
s_w1_inv_std2
,
s_w0_inv_std3
,
s_bias
);
out_v
[
id
]
=
temp
;
}
}
/* Generic poly_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
poly_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [3]
const
scalar_t
*
__restrict__
bias
,
// [1]
const
float
epsilon
,
const
int
hidden_size
)
{
float
variance
=
0.0
f
;
float
variance2
=
0.0
f
;
float
variance3
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x2
=
x
*
x
;
float
x4
=
x2
*
x2
;
float
x6
=
x4
*
x2
;
variance
+=
x2
;
variance2
+=
x4
;
variance3
+=
x6
;
}
float3
thread_variances
=
make_float3
(
variance
,
variance2
,
variance3
);
struct
SumOp
{
__device__
float3
operator
()(
const
float3
&
a
,
const
float3
&
b
)
const
{
return
make_float3
(
a
.
x
+
b
.
x
,
a
.
y
+
b
.
y
,
a
.
z
+
b
.
z
);
}
};
using
BlockReduce
=
cub
::
BlockReduce
<
float3
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
float3
block_variances
=
BlockReduce
(
reduceStore
).
Reduce
(
thread_variances
,
SumOp
{},
blockDim
.
x
);
variance
=
block_variances
.
x
;
variance2
=
block_variances
.
y
;
variance3
=
block_variances
.
z
;
__shared__
float
s_w2_inv_std
;
__shared__
float
s_w1_inv_std2
;
__shared__
float
s_w0_inv_std3
;
__shared__
float
s_bias
;
if
(
threadIdx
.
x
==
0
)
{
float
w0
=
(
float
)
weight
[
0
];
float
w1
=
(
float
)
weight
[
1
];
float
w2
=
(
float
)
weight
[
2
];
s_bias
=
(
float
)
bias
[
0
];
s_w2_inv_std
=
w2
*
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_w1_inv_std2
=
w1
*
rsqrtf
(
variance2
/
hidden_size
+
epsilon
);
s_w0_inv_std3
=
w0
*
rsqrtf
(
variance3
/
hidden_size
+
epsilon
);
}
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x2
=
x
*
x
;
float
x3
=
x2
*
x
;
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
(
scalar_t
)(
x
*
s_w2_inv_std
+
x2
*
s_w1_inv_std2
+
x3
*
s_w0_inv_std3
+
s_bias
);
}
}
}
// namespace vllm
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
...
...
@@ -347,21 +186,43 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
out
.
is_contiguous
());
if
(
input
.
stride
(
-
1
)
!=
1
)
{
input
=
input
.
contiguous
();
}
TORCH_CHECK
(
input
.
stride
(
-
1
)
==
1
);
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int64_t
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
num_dims
=
input
.
dim
();
int64_t
input_stride_d2
=
input
.
stride
(
-
2
);
int64_t
input_stride_d3
=
(
num_dims
>=
3
)
?
input
.
stride
(
-
3
)
:
0
;
int64_t
input_stride_d4
=
(
num_dims
>=
4
)
?
input
.
stride
(
-
4
)
:
0
;
int64_t
input_shape_d2
=
(
num_dims
>=
3
)
?
input
.
size
(
-
2
)
:
0
;
int64_t
input_shape_d3
=
(
num_dims
>=
4
)
?
input
.
size
(
-
3
)
:
0
;
// For large num_tokens, use smaller blocks to increase SM concurrency.
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_RANK234
(
num_dims
,
[
&
]
{
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
const
int
calculated_vec_size
=
std
::
gcd
(
16
/
sizeof
(
scalar_t
),
hidden_size
);
const
int
block_size
=
std
::
min
(
hidden_size
/
calculated_vec_size
,
max_block_size
);
dim3
block
(
block_size
);
VLLM_DISPATCH_VEC_SIZE
(
calculated_vec_size
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
,
vec_size
,
tensor_rank
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride_d2
,
input_stride_d3
,
input_stride_d4
,
input_shape_d2
,
input_shape_d3
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
});
});
}
...
...
@@ -379,6 +240,8 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
weight
.
scalar_type
()
==
input
.
scalar_type
());
TORCH_CHECK
(
input
.
scalar_type
()
==
residual
.
scalar_type
());
TORCH_CHECK
(
residual
.
is_contiguous
());
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
...
...
@@ -413,55 +276,11 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
wt_ptr
%
req_alignment_bytes
==
0
;
bool
offsets_are_multiple_of_vector_width
=
hidden_size
%
vector_width
==
0
&&
input_stride
%
vector_width
==
0
;
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
)
{
bool
batch_invariant_launch
=
vllm
::
vllm_is_batch_invariant
();
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
&&
!
batch_invariant_launch
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
}
\ No newline at end of file
#define LAUNCH_FUSED_POLY_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "poly_norm_kernel", [&] { \
vllm::poly_norm_kernel<scalar_t, width><<<grid, block, 0, stream>>>( \
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), bias.data_ptr<scalar_t>(), epsilon, \
hidden_size); \
});
void
poly_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [3]
torch
::
Tensor
&
bias
,
// [1]
double
epsilon
)
{
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
data_ptr
()
!=
input
.
data_ptr
());
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
grid
(
num_tokens
);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
block
(
std
::
min
(
hidden_size
,
max_block_size
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
out_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
out
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
out_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
LAUNCH_FUSED_POLY_NORM
(
8
);
}
else
{
LAUNCH_FUSED_POLY_NORM
(
0
);
}
}
csrc/ops.h
View file @
b5de0b4e
...
...
@@ -93,12 +93,6 @@ void fused_add_rms_norm(torch::Tensor& input, torch::Tensor& residual,
void
poly_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
bias
,
double
epsilon
);
void
rms_norm_opt
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
double
epsilon
);
void
fused_add_rms_norm_opt
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
double
epsilon
);
void
apply_repetition_penalties_
(
torch
::
Tensor
&
logits
,
const
torch
::
Tensor
&
prompt_mask
,
const
torch
::
Tensor
&
output_mask
,
...
...
csrc/opt/layernorm_kernels_opt.cu
deleted
100644 → 0
View file @
70506d98
#include "type_convert.cuh"
#include "dispatch_utils.h"
#include <torch/cuda.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/native/cuda/MemoryAccess.cuh>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/AccumulateType.h>
#include <THC/THCDeviceUtils.cuh>
#ifndef USE_ROCM
#include <cub/cub.cuh>
#else
#include <hipcub/hipcub.hpp>
#endif
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
__global__
void
rms_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
variance
+=
x
*
x
;
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
input_stride
+
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
}
/* Function specialization in the case of FP16/BF16 tensors.
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert
(
std
::
is_pod_v
<
_f16Vec
<
scalar_t
,
width
>>
);
static_assert
(
sizeof
(
_f16Vec
<
scalar_t
,
width
>
)
==
sizeof
(
scalar_t
)
*
width
);
const
int
vec_hidden_size
=
hidden_size
/
width
;
const
int64_t
vec_input_stride
=
input_stride
/
width
;
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto
*
__restrict__
input_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
input
);
auto
*
__restrict__
residual_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
residual
);
auto
*
__restrict__
weight_v
=
reinterpret_cast
<
const
_f16Vec
<
scalar_t
,
width
>*>
(
weight
);
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
int64_t
strided_id
=
blockIdx
.
x
*
vec_input_stride
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
input_v
[
strided_id
];
temp
+=
residual_v
[
id
];
variance
+=
temp
.
sum_squares
();
residual_v
[
id
]
=
temp
;
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
int64_t
strided_id
=
blockIdx
.
x
*
vec_input_stride
+
idx
;
_f16Vec
<
scalar_t
,
width
>
temp
=
residual_v
[
id
];
temp
*=
s_variance
;
temp
*=
weight_v
[
idx
];
input_v
[
strided_id
]
=
temp
;
}
}
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
int64_t
input_stride
,
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
scalar_t
z
=
input
[
blockIdx
.
x
*
input_stride
+
idx
];
z
+=
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
z
;
variance
+=
x
*
x
;
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStore
;
variance
=
BlockReduce
(
reduceStore
).
Reduce
(
variance
,
cub
::
Sum
{},
blockDim
.
x
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
input
[
blockIdx
.
x
*
input_stride
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
}
}
// namespace vllm
template
<
typename
T
,
int
reducesize
=
C10_WARP_SIZE
>
__inline__
__device__
T
WarpReduceSum_NEW
(
T
val
)
{
#pragma unroll
for
(
int
offset
=
reducesize
/
2
;
offset
>
0
;
offset
>>=
1
)
{
val
+=
WARP_SHFL_DOWN
(
val
,
offset
);
}
return
val
;
}
template
<
typename
T
,
int
block_size
=
512
>
__inline__
__device__
T
BlockReduceSum_NEW
(
T
val
,
T
*
shared
)
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
val
=
WarpReduceSum_NEW
<
T
>
(
val
);
if
constexpr
(
block_size
==
C10_WARP_SIZE
)
{
return
val
;
}
else
{
const
int
lid
=
threadIdx
.
x
%
C10_WARP_SIZE
;
const
int
wid
=
threadIdx
.
x
/
C10_WARP_SIZE
;
if
(
lid
==
0
&&
wid
<
share_size
)
{
shared
[
wid
]
=
val
;
}
__syncthreads
();
if
(
wid
==
0
&&
lid
<
share_size
)
{
val
=
WarpReduceSum_NEW
<
T
,
share_size
>
(
shared
[
lid
]);
}
return
val
;
}
}
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
__global__
void
fused_add_rms_kernel_opt
(
scalar_t
*
input
,
scalar_t
*
residual
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
__shared__
T_ACC
val_shared
[
share_size
];
__shared__
T_ACC
s_rstd
;
T_ACC
val
=
0
;
int
i
=
blockIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
tcol
=
cols
/
Vec
;
using
LoadT
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
Vec
>
;
scalar_t
intput_vec
[
Vec
];
scalar_t
residual_vec
[
Vec
];
T_ACC
trstd
;
int64_t
idx
=
i
*
tcol
+
j
;
idx
*=
Vec
;
if
(
j
<
tcol
)
{
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
*
(
LoadT
*
)
residual_vec
=
*
(
LoadT
*
)(
residual
+
idx
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
residual_vec
[
ii
]
+=
intput_vec
[
ii
];
val
+=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
residual_vec
[
ii
]);
}
}
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
if
(
j
==
0
)
s_rstd
=
c10
::
cuda
::
compat
::
rsqrt
(
val
/
cols
+
eps
);
__syncthreads
();
trstd
=
s_rstd
;
if
(
j
<
tcol
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
int
jj
=
j
*
Vec
+
ii
;
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
residual_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
}
*
(
LoadT
*
)(
residual
+
idx
)
=*
(
LoadT
*
)
residual_vec
;
*
(
LoadT
*
)(
input
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
}
template
<
typename
scalar_t
,
typename
T_ACC
,
int
Vec
=
4
,
int
block_size
=
512
>
__global__
void
fused_rms_kernel_opt
(
scalar_t
*
input
,
scalar_t
*
output
,
scalar_t
*
gamma
,
int
cols
,
T_ACC
eps
)
{
constexpr
int
share_size
=
block_size
/
C10_WARP_SIZE
;
__shared__
T_ACC
val_shared
[
share_size
];
__shared__
T_ACC
s_rstd
;
T_ACC
val
=
0
;
int
i
=
blockIdx
.
x
;
int
j
=
threadIdx
.
x
;
int
tcol
=
cols
/
Vec
;
using
LoadT
=
at
::
native
::
memory
::
aligned_vector
<
scalar_t
,
Vec
>
;
scalar_t
intput_vec
[
Vec
];
T_ACC
trstd
;
int64_t
idx
=
i
*
tcol
+
j
;
idx
*=
Vec
;
if
(
j
<
tcol
)
{
*
(
LoadT
*
)
intput_vec
=
*
(
LoadT
*
)(
input
+
idx
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
)
{
val
+=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
static_cast
<
T_ACC
>
(
intput_vec
[
ii
]);
}
}
val
=
BlockReduceSum_NEW
<
T_ACC
,
block_size
>
(
val
,
val_shared
);
if
(
j
==
0
)
s_rstd
=
c10
::
cuda
::
compat
::
rsqrt
(
val
/
cols
+
eps
);
__syncthreads
();
trstd
=
s_rstd
;
if
(
j
<
tcol
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
Vec
;
ii
++
){
int
jj
=
j
*
Vec
+
ii
;
intput_vec
[
ii
]
=
static_cast
<
T_ACC
>
(
intput_vec
[
ii
])
*
trstd
*
static_cast
<
T_ACC
>
(
gamma
[
jj
]);
}
*
(
LoadT
*
)(
output
+
idx
)
=*
(
LoadT
*
)
intput_vec
;
}
}
void
rms_norm_opt
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
input
.
stride
(
-
1
)
==
1
);
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int64_t
input_stride
=
input
.
stride
(
-
2
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
hidden_size
%
16
==
0
&&
hidden_size
<=
16384
&&
ptrs_are_aligned
){
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
input
.
scalar_type
(),
"fused_add_rms_norm_kernel"
,
[
&
]
{
using
T_ACC
=
at
::
acc_type
<
scalar_t
,
true
>
;
T_ACC
eps
=
epsilon
;
scalar_t
*
self_data
=
input
.
expect_contiguous
()
->
data_ptr
<
scalar_t
>
();
scalar_t
*
out_data
=
out
.
expect_contiguous
()
->
data_ptr
<
scalar_t
>
();
scalar_t
*
weight_data
=
weight
.
expect_contiguous
()
->
data_ptr
<
scalar_t
>
();
if
(
hidden_size
<=
1024
){
fused_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
128
><<<
num_tokens
,
128
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
2048
){
fused_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
256
><<<
num_tokens
,
256
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
4096
){
if
(
num_tokens
>
1200
){
fused_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
512
><<<
num_tokens
,
512
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_rms_kernel_opt
<
scalar_t
,
T_ACC
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
if
(
hidden_size
<=
8192
){
fused_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_rms_kernel_opt
<
scalar_t
,
T_ACC
,
16
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
out_data
,
weight_data
,
hidden_size
,
eps
);
}
});
}
else
{
dim3
grid
(
num_tokens
);
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
input_stride
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
}
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), input_stride, \
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
epsilon, num_tokens, hidden_size); \
});
void
fused_add_rms_norm_opt
(
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
double
epsilon
)
{
TORCH_CHECK
(
residual
.
is_contiguous
());
TORCH_CHECK
(
weight
.
is_contiguous
());
int
hidden_size
=
input
.
size
(
-
1
);
int64_t
input_stride
=
input
.
stride
(
-
2
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
hidden_size
%
16
==
0
&&
hidden_size
<=
16384
&&
ptrs_are_aligned
){
AT_DISPATCH_FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
input
.
scalar_type
(),
"fused_add_rms_norm_kernel"
,
[
&
]
{
using
T_ACC
=
at
::
acc_type
<
scalar_t
,
true
>
;
T_ACC
eps
=
epsilon
;
scalar_t
*
self_data
=
input
.
expect_contiguous
()
->
data_ptr
<
scalar_t
>
();
scalar_t
*
other_data
=
residual
.
expect_contiguous
()
->
data_ptr
<
scalar_t
>
();
scalar_t
*
weight_data
=
weight
.
expect_contiguous
()
->
data_ptr
<
scalar_t
>
();
if
(
hidden_size
<=
1024
){
fused_add_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
128
><<<
num_tokens
,
128
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
2048
){
fused_add_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
256
><<<
num_tokens
,
256
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
if
(
hidden_size
<=
4096
){
if
(
num_tokens
>
1200
){
fused_add_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
512
><<<
num_tokens
,
512
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_add_rms_kernel_opt
<
scalar_t
,
T_ACC
,
4
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
}
else
if
(
hidden_size
<=
8192
){
fused_add_rms_kernel_opt
<
scalar_t
,
T_ACC
,
8
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
else
{
fused_add_rms_kernel_opt
<
scalar_t
,
T_ACC
,
16
,
1024
><<<
num_tokens
,
1024
,
0
,
stream
>>>
(
self_data
,
other_data
,
weight_data
,
hidden_size
,
eps
);
}
});
}
else
{
dim3
grid
(
num_tokens
);
/* This kernel is memory-latency bound in many scenarios.
When num_tokens is large, a smaller block size allows
for increased block occupancy on CUs and better latency
hiding on global mem ops. */
const
int
max_block_size
=
(
num_tokens
<
256
)
?
1024
:
256
;
dim3
block
(
std
::
min
(
hidden_size
,
max_block_size
));
/*If the tensor types are FP16/BF16, try to use the optimized kernel
with packed + vectorized ops.
Max optimization is achieved with a width-8 vector of FP16/BF16s
since we can load at most 128 bits at once in a global memory op.
However, this requires each tensor's data to be aligned to 16
bytes.
*/
constexpr
int
vector_width
=
8
;
constexpr
int
req_alignment_bytes
=
vector_width
*
2
;
// vector_width * sizeof(bfloat16 or float16) (float32
// falls back to non-vectorized version anyway)
bool
ptrs_are_aligned
=
inp_ptr
%
req_alignment_bytes
==
0
&&
res_ptr
%
req_alignment_bytes
==
0
&&
wt_ptr
%
req_alignment_bytes
==
0
;
bool
offsets_are_multiple_of_vector_width
=
hidden_size
%
vector_width
==
0
&&
input_stride
%
vector_width
==
0
;
if
(
ptrs_are_aligned
&&
offsets_are_multiple_of_vector_width
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
LAUNCH_FUSED_ADD_RMS_NORM
(
0
);
}
}
}
\ No newline at end of file
csrc/torch_bindings.cpp
View file @
b5de0b4e
...
...
@@ -199,19 +199,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"apply_repetition_penalties_"
,
torch
::
kCUDA
,
&
apply_repetition_penalties_
);
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
ops
.
def
(
"rms_norm_opt(Tensor! out, Tensor input, Tensor weight, float epsilon) -> "
"()"
);
ops
.
impl
(
"rms_norm_opt"
,
torch
::
kCUDA
,
&
rms_norm_opt
);
// In-place fused Add and RMS Normalization. (opt)
ops
.
def
(
"fused_add_rms_norm_opt(Tensor! input, Tensor! residual, Tensor weight, "
"float epsilon) -> ()"
);
ops
.
impl
(
"fused_add_rms_norm_opt"
,
torch
::
kCUDA
,
&
fused_add_rms_norm_opt
);
// Layernorm-quant
// Apply Root Mean Square (RMS) Normalization to the input tensor.
// ops.def(
...
...
requirements/rocm.txt
View file @
b5de0b4e
...
...
@@ -26,11 +26,11 @@ setuptools_scm>=8
cmake==3.29
quart
fastrlock==0.8.3
cupy==12.3.0
#
cupy==12.3.0
torch
>=2.5.1,<=2.7.1
torchvision
>=0.20.1,<
=0.2
2
.0
triton == 3.
1
torch
==2.9.0
torchvision
=
=0.2
4
.0
triton == 3.
3.0
flash_attn == 2.6.1
flash_mla == 1.0.0
lightop == 0.6.0
...
...
setup.py
View file @
b5de0b4e
...
...
@@ -36,7 +36,7 @@ if int(os.environ.get('SKIP_VLLM_BUILD', '0')) == 1:
skip_vllm_build
=
True
def
prepare_so_files
():
source_dir
=
"so.tar.gz"
source_dir
=
"
csrc/
so.tar.gz"
target_dir
=
"vllm"
if
not
os
.
path
.
exists
(
source_dir
):
...
...
@@ -542,9 +542,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
if
sha
!=
'Unknown'
:
if
sha
is
None
:
sha
=
get_sha
(
vllm_root
)
version
=
'das.opt1.rc2.'
+
sha
[:
7
]
version
=
'das.opt1.rc2.
torch29.
'
+
sha
[:
7
]
else
:
version
=
'das.opt1.rc2'
version
=
'das.opt1.rc2
.torch29
'
# dtk version
...
...
vllm/_custom_ops.py
View file @
b5de0b4e
...
...
@@ -286,17 +286,6 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
torch
.
ops
.
_C
.
fused_add_rms_norm
(
input
,
residual
,
weight
,
epsilon
)
# layer norm ops (opt)
def
rms_norm_opt
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
rms_norm_opt
(
out
,
input
,
weight
,
epsilon
)
def
fused_add_rms_norm_opt
(
input
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
torch
.
ops
.
_C
.
fused_add_rms_norm_opt
(
input
,
residual
,
weight
,
epsilon
)
def
poly_norm
(
out
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
epsilon
:
float
)
->
None
:
# TODO: Remove this contiguous call when the kernel is updated to support non-contiguous input
...
...
vllm/model_executor/layers/layernorm.py
View file @
b5de0b4e
...
...
@@ -23,7 +23,8 @@ def rms_norm(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon
:
float
)
->
torch
.
Tensor
:
from
vllm
import
_custom_ops
as
ops
out
=
torch
.
empty_like
(
x
)
if
envs
.
VLLM_USE_OPT_OP
:
# if envs.VLLM_USE_OPT_OP:
if
False
:
ops
.
rms_norm_opt
(
out
,
x
,
...
...
@@ -44,7 +45,8 @@ def fused_add_rms_norm(
x
:
torch
.
Tensor
,
residual
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
variance_epsilon
:
float
)
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
from
vllm
import
_custom_ops
as
ops
if
envs
.
VLLM_USE_OPT_OP
:
# if envs.VLLM_USE_OPT_OP:
if
False
:
ops
.
fused_add_rms_norm_opt
(
x
,
residual
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment