Unverified Commit a7ca2972 authored by gongenlei's avatar gongenlei Committed by GitHub
Browse files

[coati] Fix LlamaCritic (#3475)



* mv LlamaForCausalLM to LlamaModel

* rm unused imports

---------
Co-authored-by: gongenlei <gongenlei@baidu.com>
parent 8f2c55f9
from typing import Optional
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
from transformers import LlamaConfig, LlamaModel
from ..base import Critic
......@@ -28,11 +27,11 @@ class LlamaCritic(Critic):
**kwargs) -> None:
if pretrained is not None:
model = LlamaForCausalLM.from_pretrained(pretrained)
model = LlamaModel.from_pretrained(pretrained)
elif config is not None:
model = LlamaForCausalLM(config)
model = LlamaModel(config)
else:
model = LlamaForCausalLM(LlamaConfig())
model = LlamaModel(LlamaConfig())
if checkpoint:
model.gradient_checkpointing_enable()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment