Unverified commit c9ee3d35, authored by Liangsheng Yin, committed by GitHub
Browse files

Fix model forward grad (#628)

parent 41d1f677
@@ -360,6 +360,7 @@ class ChatGLMForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.sampler = Sampler()
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
@@ -368,6 +368,7 @@ class DbrxForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
@@ -601,6 +601,7 @@ class Grok1ModelForCausalLM(nn.Module):
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
@@ -275,6 +275,7 @@ class LlamaForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
@@ -31,6 +31,7 @@ class LlamaForClassification(nn.Module):
         )
         self.eos_token_id = config.eos_token_id
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
@@ -95,6 +95,7 @@ class LlavaLlamaForCausalLM(nn.Module):
         return image_features
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.LongTensor,
......
@@ -106,6 +106,7 @@ class LlavaVidForCausalLM(nn.Module):
         return image_features
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.LongTensor,
......
@@ -283,6 +283,7 @@ class MiniCPMForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
@@ -460,6 +460,7 @@ class MixtralForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
@@ -322,6 +322,7 @@ class QuantMixtralForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
@@ -237,6 +237,7 @@ class QWenLMHeadModel(nn.Module):
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
@@ -261,6 +261,7 @@ class Qwen2ForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
@@ -355,6 +355,7 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.sampler = Sampler()
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
@@ -235,6 +235,7 @@ class StableLmForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment