Commit fdd0f2f4 authored by Woosuk Kwon

Minor

parent 7f985166
```diff
@@ -216,7 +216,7 @@ class Scheduler:
             self.block_manager.fork(parent_seq, seq)
         # Append a new token to the sequence.
-        seq.append(next_token)
+        seq.append([next_token])
         # Check if the sequence has generated a stop token.
         if next_token in stop_token_ids:
```
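The Scheduler hunk changes the `seq.append` call to pass a list, implying the sequence's append method now takes a batch of token ids rather than a single int. A minimal sketch of that interface, with a hypothetical `Sequence` class standing in for the repository's own:

```python
from typing import List


class Sequence:
    """Hypothetical stand-in for the repo's sequence type."""

    def __init__(self, token_ids: List[int]) -> None:
        self.token_ids = list(token_ids)

    def append(self, token_ids: List[int]) -> None:
        # Taking a list lets callers append one or more tokens per step.
        self.token_ids.extend(token_ids)


seq = Sequence([1, 2, 3])
next_token = 4
seq.append([next_token])  # the call shape used by the Scheduler after this commit
assert seq.token_ids == [1, 2, 3, 4]
```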
```diff
@@ -13,7 +13,7 @@ class Sampler(nn.Module):
         embedding: torch.Tensor,
     ) -> None:
         super().__init__()
-        self.embedding = embedding.t() # [hidden_size, vocab_size]
+        self.embedding = embedding # [vocab_size, hidden_size]

     def forward(
         self,
```
```diff
@@ -31,7 +31,7 @@ class Sampler(nn.Module):
         hidden_states = hidden_states[last_token_indicies]
         # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, self.embedding)
+        logits = torch.matmul(hidden_states, self.embedding.t())
         # Sample the next tokens.
         # TODO(woosuk): Implement other sampling methods.
```
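The two Sampler hunks are one logical change: rather than caching a pre-transposed `[hidden_size, vocab_size]` weight at construction time, the module keeps the embedding in its native `[vocab_size, hidden_size]` layout and transposes it inside the matmul (`Tensor.t()` returns a view, so the deferred transpose adds no copy). A quick shape check with illustrative sizes:

```python
import torch

vocab_size, hidden_size, num_seqs = 32000, 4096, 2
embedding = torch.randn(vocab_size, hidden_size)    # stored untransposed
hidden_states = torch.randn(num_seqs, hidden_size)  # last-token hidden states

# logits[i, v] is the score of vocabulary entry v for sequence i.
logits = torch.matmul(hidden_states, embedding.t())
assert logits.shape == (num_seqs, vocab_size)
```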
```diff
@@ -165,6 +165,7 @@ class Worker:
         output = self.model(
             input_ids=input_tokens,
             positions=input_positions,
+            kv_caches=self.gpu_cache,
             input_metadata=input_metadata,
             cache_events=cache_events,
         )
```
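The Worker hunk threads the GPU KV cache into the model call explicitly instead of leaving it implicit in model state. A sketch of the forward signature this implies; the names below (`MyModel`, the `(key_cache, value_cache)` layout) are illustrative assumptions, not the repository's actual API:

```python
from typing import Any, List, Optional, Tuple

import torch
import torch.nn as nn

# Assumed layout: one (key_cache, value_cache) pair per transformer layer.
KVCache = Tuple[torch.Tensor, torch.Tensor]


class MyModel(nn.Module):
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: Any,
        cache_events: Optional[List[Any]],
    ) -> torch.Tensor:
        # A real model would have each attention layer read and update
        # kv_caches[layer_idx] in place; this stub only returns a dummy tensor.
        return torch.zeros(input_ids.shape[0], 1)
```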