Unverified Commit 5d7b0502 authored by TechxGenus's avatar TechxGenus Committed by GitHub
Browse files

Fix fused models for tf >= 4.39 (#418)

parent 0fa9a2c1
......@@ -83,6 +83,14 @@ class LlamaLikeModel(nn.Module):
self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
self.norm = norm
self.last_forward_num_tokens = 0
@property
def embed_tokens(self):
return self.embedding
@property
def layers(self):
return self.blocks
@torch.inference_mode()
def forward(
......
......@@ -86,7 +86,7 @@ common_setup_kwargs = {
requirements = [
"torch>=2.0.1",
"transformers>=4.35.0,<=4.38.2",
"transformers>=4.35.0",
"tokenizers>=0.12.1",
"typing_extensions>=4.8.0",
"accelerate",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment