Commit 0e640807 authored by zhuwenwen's avatar zhuwenwen
Browse files

update attention kernels and version

parent 69341fde
...@@ -91,7 +91,7 @@ __device__ void paged_attention_kernel( ...@@ -91,7 +91,7 @@ __device__ void paged_attention_kernel(
const int max_num_blocks_per_seq, const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads] const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride, const int q_stride, const int kv_block_stride, const int kv_head_stride,
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks, const float k_scale, const float v_scale, const int tp_rank, const int blocksparse_local_blocks,
const int blocksparse_vert_stride, const int blocksparse_block_size, const int blocksparse_vert_stride, const int blocksparse_block_size,
const int blocksparse_head_sliding_step) {} const int blocksparse_head_sliding_step) {}
...@@ -345,7 +345,7 @@ __device__ void paged_attention_kernel( ...@@ -345,7 +345,7 @@ __device__ void paged_attention_kernel(
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>( Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
k_ptr + offset1 * BLOCK_SIZE * x + offset2); k_ptr + offset1 * BLOCK_SIZE * x + offset2);
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>( k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
k_vec_quant, kv_scale); k_vec_quant, k_scale);
} }
} }
} }
......
...@@ -394,21 +394,21 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -394,21 +394,21 @@ def get_version_add(sha: Optional[str] = None) -> str:
# torch version # torch version
version += ".torch" + torch.__version__[:5] version += ".torch" + torch.__version__[:5]
new_version_content = f"""\ new_version_content = f"""
import warnings import warnings
try: try:
import vllm.commit_id import vllm.commit_id
__commit__ = vllm.commit_id.__commit__ __commit__ = vllm.commit_id.__commit__
except Exception as e: except Exception as e:
warnings.warn(f"Failed to read commit hash:\\n{e}", warnings.warn(f"Failed to read commit hash:\\n + str(e)",
RuntimeWarning, RuntimeWarning,
stacklevel=2) stacklevel=2)
__commit__ = "COMMIT_HASH_PLACEHOLDER" __commit__ = "COMMIT_HASH_PLACEHOLDER"
__version__ = "0.5.3.post1" __version__ = "0.5.3.post1"
__dcu_version__ = '0.5.3.post1+{version}' __dcu_version__ = f'0.5.3.post1+{version}'
""".format(version=version) """
with open(add_version_path, encoding="utf-8",mode="w") as file: with open(add_version_path, encoding="utf-8",mode="w") as file:
file.write(new_version_content) file.write(new_version_content)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment