Unverified Commit 8f4f77a7 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[BugFix] Fix false assertion with spec-decode=[2,4,..] and TP>2 (#29036)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent 22e44ad5
...@@ -921,7 +921,7 @@ class CompilationConfig: ...@@ -921,7 +921,7 @@ class CompilationConfig:
self, uniform_decode_query_len: int, tensor_parallel_size: int self, uniform_decode_query_len: int, tensor_parallel_size: int
): ):
multiple_of = uniform_decode_query_len multiple_of = uniform_decode_query_len
if tensor_parallel_size > 1: if tensor_parallel_size > 1 and self.pass_config.enable_sequence_parallelism:
multiple_of = max(uniform_decode_query_len, tensor_parallel_size) multiple_of = max(uniform_decode_query_len, tensor_parallel_size)
if ( if (
multiple_of % uniform_decode_query_len != 0 multiple_of % uniform_decode_query_len != 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment