Unverified Commit c9ee3d35 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Fix model forward grad (#628)

parent 41d1f677
......@@ -360,6 +360,7 @@ class ChatGLMForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -368,6 +368,7 @@ class DbrxForCausalLM(nn.Module):
)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -601,6 +601,7 @@ class Grok1ModelForCausalLM(nn.Module):
# Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -275,6 +275,7 @@ class LlamaForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -31,6 +31,7 @@ class LlamaForClassification(nn.Module):
)
self.eos_token_id = config.eos_token_id
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -95,6 +95,7 @@ class LlavaLlamaForCausalLM(nn.Module):
return image_features
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
......
......@@ -106,6 +106,7 @@ class LlavaVidForCausalLM(nn.Module):
return image_features
@torch.no_grad()
def forward(
self,
input_ids: torch.LongTensor,
......
......@@ -283,6 +283,7 @@ class MiniCPMForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -460,6 +460,7 @@ class MixtralForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -322,6 +322,7 @@ class QuantMixtralForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -237,6 +237,7 @@ class QWenLMHeadModel(nn.Module):
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -261,6 +261,7 @@ class Qwen2ForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -355,6 +355,7 @@ class Qwen2MoeForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......
......@@ -235,6 +235,7 @@ class StableLmForCausalLM(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment