Unverified Commit 5d7b0502 authored by TechxGenus's avatar TechxGenus Committed by GitHub
Browse files

Fix fused models for tf >= 4.39 (#418)

parent 0fa9a2c1
...@@ -83,6 +83,14 @@ class LlamaLikeModel(nn.Module): ...@@ -83,6 +83,14 @@ class LlamaLikeModel(nn.Module):
self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks) self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
self.norm = norm self.norm = norm
self.last_forward_num_tokens = 0 self.last_forward_num_tokens = 0
@property
def embed_tokens(self):
return self.embedding
@property
def layers(self):
return self.blocks
@torch.inference_mode() @torch.inference_mode()
def forward( def forward(
......
...@@ -86,7 +86,7 @@ common_setup_kwargs = { ...@@ -86,7 +86,7 @@ common_setup_kwargs = {
requirements = [ requirements = [
"torch>=2.0.1", "torch>=2.0.1",
"transformers>=4.35.0,<=4.38.2", "transformers>=4.35.0",
"tokenizers>=0.12.1", "tokenizers>=0.12.1",
"typing_extensions>=4.8.0", "typing_extensions>=4.8.0",
"accelerate", "accelerate",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment