Unverified Commit cb226321 authored by Chuan (Richard) Li's avatar Chuan (Richard) Li Committed by GitHub
Browse files

[Bugfix][Minor] Fix potential NameError in mamba backend selector and misc typos (#35886)


Signed-off-by: default avatarLi <chuali@amd.com>
parent e054f152
...@@ -369,7 +369,7 @@ class KimiK25ForConditionalGeneration( ...@@ -369,7 +369,7 @@ class KimiK25ForConditionalGeneration(
target_dtype = next(self.vision_tower.parameters()).dtype target_dtype = next(self.vision_tower.parameters()).dtype
pixel_values = pixel_values.to(target_dtype) pixel_values = pixel_values.to(target_dtype)
assert isinstance(grid_thws, torch.Tensor), ( assert isinstance(grid_thws, torch.Tensor), (
f"expect grid_thws to be a tensor, get {type(grid_thws)}" f"expect grid_thws to be a tensor, got {type(grid_thws)}"
) )
# In some cases (e.g. with merger), grid_thws has an extra middle dimension # In some cases (e.g. with merger), grid_thws has an extra middle dimension
grid_thws = grid_thws.reshape(-1, grid_thws.shape[-1]) grid_thws = grid_thws.reshape(-1, grid_thws.shape[-1])
......
...@@ -749,7 +749,10 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat ...@@ -749,7 +749,10 @@ class FlexAttentionMetadataBuilder(AttentionMetadataBuilder[FlexAttentionMetadat
prefix_kv_lens = None prefix_kv_lens = None
suffix_kv_lens = None suffix_kv_lens = None
if use_cascade: if use_cascade:
raise NotImplementedError("Not yet my friend") raise NotImplementedError(
"Cascade prefix attention is not yet implemented "
"for FlexAttention backend"
)
block_size = self.kv_cache_spec.block_size block_size = self.kv_cache_spec.block_size
max_possible_seq_len = self.model_config.max_model_len max_possible_seq_len = self.model_config.max_model_len
......
...@@ -253,7 +253,7 @@ def make_local_attention_virtual_batches( ...@@ -253,7 +253,7 @@ def make_local_attention_virtual_batches(
# seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1] # seqlens_q_local = [2, 2, 1, 4, 4, 1, 4, 1]
# #
# First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1]) # First Get batched arange. (E.g., [2, 4, 2] -> [0, 1, 0, 1, 2, 3, 0, 1])
# (TODO: max a utility to share this code with _prepare_inputs) # (TODO: make a utility to share this code with _prepare_inputs)
# arange step 1. [2, 4, 2] -> [2, 6, 8] # arange step 1. [2, 4, 2] -> [2, 6, 8]
cu_num_blocks = np.cumsum(local_blocks) cu_num_blocks = np.cumsum(local_blocks)
virtual_batches = cu_num_blocks[-1] virtual_batches = cu_num_blocks[-1]
......
...@@ -149,8 +149,8 @@ def _cached_get_mamba_attn_backend( ...@@ -149,8 +149,8 @@ def _cached_get_mamba_attn_backend(
selected_backend = MambaAttentionBackendEnum[backend_name] selected_backend = MambaAttentionBackendEnum[backend_name]
except KeyError as e: except KeyError as e:
raise ValueError( raise ValueError(
f"Invalid mamba attention backend type: '{backend_name}'. Valid " f"Invalid mamba attention backend type: '{mamba_type}'. Valid "
f"backends are: {list(MambaAttentionBackendEnum.__members__.keys())}" f"types are: {list(MAMBA_TYPE_TO_BACKEND_MAP.keys())}"
) from e ) from e
mamba_attn_backend = selected_backend.get_class() mamba_attn_backend = selected_backend.get_class()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment