Unverified Commit b1e5a33a authored by Lifu Huang, committed by GitHub

Eliminate stream sync to speed up LoRA batch init (#6960)

parent 9d5fa68b
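
The hunks below make two changes. First, the `not hasattr(self, name) or ... is None` guards are collapsed into the equivalent `getattr(self, name, None) is None` idiom. Second, the small `output_offset` table is now built in a single `torch.tensor(...)` call instead of a `torch.empty` allocation followed by a slice assignment, dropping the intermediate copy that the two-step version issued during LoRA batch initialization. A minimal standalone sketch of the resulting pattern (the `OffsetCache` class and its argument names are illustrative only, not part of sglang):

```python
import torch

# Standalone sketch of the pattern in this commit, not the sglang code
# itself; the class and argument names here are hypothetical.
class OffsetCache:
    def set_offsets(self, output_dim_q: int, output_dim_kv: int, device: str) -> None:
        # getattr(obj, name, None) covers both "attribute not set yet" and
        # "attribute explicitly None" in one check, replacing the longer
        # `not hasattr(...) or ... is None` form used before this commit.
        if getattr(self, "output_offset", None) is None:
            # Build the tensor in one torch.tensor(...) call rather than
            # allocating with torch.empty(...) and filling it afterwards
            # via a slice assignment, which issues an extra copy.
            self.output_offset = torch.tensor(
                [
                    0,
                    output_dim_q,
                    output_dim_q + output_dim_kv,
                    output_dim_q + 2 * output_dim_kv,
                ],
                dtype=torch.int32,
                device=device,
            )


cache = OffsetCache()
cache.set_offsets(4096, 1024, device="cpu")  # "cuda" on a GPU host
print(cache.output_offset)  # tensor([   0, 4096, 5120, 6144], dtype=torch.int32)
```

Because the guard short-circuits once `output_offset` exists, later batch initializations reuse the cached tensor rather than rebuilding it.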
@@ -137,7 +137,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.A_buffer_gate_up = A_buffer
         if self.lora_backend.fuse_stacked_lora_b:
             # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
-            if not hasattr(self, "B_buffer_gate_up") or self.B_buffer_gate_up is None:
+            if getattr(self, "B_buffer_gate_up", None) is None:
                 self.B_buffer_gate_up = torch.empty(
                     (
                         B_buffer[0].shape[0],
@@ -202,7 +202,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]

             # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-            if not hasattr(self, "B_buffer_qkv") or self.B_buffer_qkv is None:
+            if getattr(self, "B_buffer_qkv", None) is None:
                 self.B_buffer_qkv = torch.empty(
                     (
                         B_buffer_q[0].shape[0],
@@ -221,20 +221,17 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             )

             # Offsets of q/k/v in output dimension
-            if not hasattr(self, "output_offset") or self.output_offset is None:
-                self.output_offset = torch.empty(
-                    4, dtype=torch.int32, device=B_buffer_q.device
-                )
-                self.output_offset[:4] = torch.tensor(
-                    [
-                        0,
-                        output_dim_q,
-                        output_dim_q + output_dim_kv,
-                        output_dim_q + 2 * output_dim_kv,
-                    ],
-                    dtype=torch.int32,
-                    device=B_buffer_q.device,
-                )
+            if getattr(self, "output_offset", None) is None:
+                self.output_offset = torch.tensor(
+                    [
+                        0,
+                        output_dim_q,
+                        output_dim_q + output_dim_kv,
+                        output_dim_q + 2 * output_dim_kv,
+                    ],
+                    dtype=torch.int32,
+                    device=B_buffer_q.device,
+                )
             # For computing number of launched blocks
             self.max_qkv_out_dim = max(output_dim_q, output_dim_kv)
         else:
...
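
Design note: the `getattr(..., None) is None` guards make each buffer a one-time, lazy allocation, so repeated batch initializations reuse the cached device memory instead of reallocating it. The diff is truncated above; the remaining changes belong to the same PR (#6960).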