Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cbbc9044
Unverified
Commit
cbbc9044
authored
Jul 30, 2024
by
Tyler Michael Smith
Committed by
GitHub
Jul 30, 2024
Browse files
[Kernel] Squash a few more warnings (#6914)
parent
5cf9254a
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
8 additions
and
5 deletions
+8
-5
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+2
-2
csrc/quantization/aqlm/gemm_kernels.cu
csrc/quantization/aqlm/gemm_kernels.cu
+0
-2
csrc/quantization/fp8/amd/quant_utils.cuh
csrc/quantization/fp8/amd/quant_utils.cuh
+2
-0
csrc/quantization/fp8/nvidia/quant_utils.cuh
csrc/quantization/fp8/nvidia/quant_utils.cuh
+2
-0
csrc/quantization/squeezellm/quant_cuda_kernel.cu
csrc/quantization/squeezellm/quant_cuda_kernel.cu
+2
-1
No files found.
csrc/attention/attention_kernels.cu
View file @
cbbc9044
...
...
@@ -706,7 +706,7 @@ void paged_attention_v1_launcher(
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
[[
maybe_unused
]]
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
assert
(
head_size
%
thread_group_size
==
0
);
// NOTE: alibi_slopes is optional.
...
...
@@ -865,7 +865,7 @@ void paged_attention_v2_launcher(
int
kv_block_stride
=
key_cache
.
stride
(
0
);
int
kv_head_stride
=
key_cache
.
stride
(
1
);
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
[[
maybe_unused
]]
int
thread_group_size
=
MAX
(
WARP_SIZE
/
BLOCK_SIZE
,
1
);
assert
(
head_size
%
thread_group_size
==
0
);
// NOTE: alibi_slopes is optional.
...
...
csrc/quantization/aqlm/gemm_kernels.cu
View file @
cbbc9044
...
...
@@ -273,8 +273,6 @@ __global__ void Code2x8Dequant(
}
__syncthreads
();
float
res
=
0
;
int
iters
=
(
prob_k
/
8
-
1
)
/
(
8
*
32
)
+
1
;
while
(
iters
--
)
{
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
...
...
csrc/quantization/fp8/amd/quant_utils.cuh
View file @
cbbc9044
...
...
@@ -526,6 +526,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
#endif
assert
(
false
);
return
{};
// Squash missing return statement warning
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
...
...
@@ -536,6 +537,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
#endif
assert
(
false
);
return
{};
// Squash missing return statement warning
}
// The following macro is used to dispatch the conversion function based on
...
...
csrc/quantization/fp8/nvidia/quant_utils.cuh
View file @
cbbc9044
...
...
@@ -508,6 +508,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
#endif
assert
(
false
);
return
{};
// Squash missing return statement warning
}
template
<
typename
Tout
,
typename
Tin
,
Fp8KVCacheDataType
kv_dt
>
...
...
@@ -520,6 +521,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
#endif
assert
(
false
);
return
{};
// Squash missing return statement warning
}
// The following macro is used to dispatch the conversion function based on
...
...
csrc/quantization/squeezellm/quant_cuda_kernel.cu
View file @
cbbc9044
...
...
@@ -203,7 +203,8 @@ void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
#endif
mat
.
data_ptr
<
int
>
(),
#ifndef USE_ROCM
(
half2
*
)
mul
.
data
<
at
::
Half
>
(),
(
__half
*
)
lookup_table
.
data_ptr
<
at
::
Half
>
(),
(
half2
*
)
mul
.
data_ptr
<
at
::
Half
>
(),
(
__half
*
)
lookup_table
.
data_ptr
<
at
::
Half
>
(),
#else
(
float2
*
)
mul
.
data_ptr
<
float
>
(),
(
__half
*
)
lookup_table
.
data_ptr
<
at
::
Half
>
(),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment