Commit 4730762e authored by Patrick Devine's avatar Patrick Devine Committed by Michael Yang
Browse files

add safetensors version

parent d88582df
......@@ -20,7 +20,7 @@ type LlamaModel struct {
ModelData
}
func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {
slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
data := r.storage.(*pytorch.HalfStorage).Data
......@@ -105,9 +105,16 @@ func (m *LlamaModel) GetTensors() error {
matches := re.FindAllStringSubmatch(l.Name, -1)
if len(matches) > 0 {
slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
wt := l.WriterTo.(torchWriterTo)
wt.handler = llamaLayerHandler
l.WriterTo = wt
switch l.WriterTo.(type) {
case torchWriterTo:
wt := l.WriterTo.(torchWriterTo)
wt.handler = llamaTorchLayerHandler
l.WriterTo = wt
case safetensorWriterTo:
wt := l.WriterTo.(safetensorWriterTo)
wt.handler = mistralLayerHandler
l.WriterTo = wt
}
}
m.Tensors = append(m.Tensors, l)
}
......
......@@ -281,6 +281,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
return nil, fmt.Errorf("No architecture specified to convert")
case 1:
switch params.Architectures[0] {
case "LlamaForCausalLM":
return &LlamaModel{
ModelData{
Name: name,
Path: dirPath,
Params: params,
Format: m,
},
}, nil
case "MistralForCausalLM":
return &MistralModel{
ModelData{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment