"examples/git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "47ecb2238749e46d3bcfa30523850c92864a1837"
Unverified Commit a7ca2972 authored by gongenlei's avatar gongenlei Committed by GitHub
Browse files

[coati] Fix LlamaCritic (#3475)



* mv LlamaForCausalLM to LlamaModel

* rm unused imports

---------
Co-authored-by: default avatargongenlei <gongenlei@baidu.com>
parent 8f2c55f9
from typing import Optional from typing import Optional
import torch
import torch.nn as nn import torch.nn as nn
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM from transformers import LlamaConfig, LlamaModel
from ..base import Critic from ..base import Critic
...@@ -28,11 +27,11 @@ class LlamaCritic(Critic): ...@@ -28,11 +27,11 @@ class LlamaCritic(Critic):
**kwargs) -> None: **kwargs) -> None:
if pretrained is not None: if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained) model = LlamaModel.from_pretrained(pretrained)
elif config is not None: elif config is not None:
model = LlamaForCausalLM(config) model = LlamaModel(config)
else: else:
model = LlamaForCausalLM(LlamaConfig()) model = LlamaModel(LlamaConfig())
if checkpoint: if checkpoint:
model.gradient_checkpointing_enable() model.gradient_checkpointing_enable()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment