Unverified Commit d3b4b997 authored by Daniel Hiltgen's avatar Daniel Hiltgen Committed by GitHub
Browse files

app: add code for macOS and Windows apps under 'app' (#12933)



* app: add code for macOS and Windows apps under 'app'

* app: add readme

* app: windows and linux only for now

* ci: fix ui CI validation

---------
Co-authored-by: default avatarjmorganca <jmorganca@gmail.com>
parent a4770107
package updater
import (
"archive/zip"
"io/fs"
"os"
"path/filepath"
"strings"
"testing"
)
func TestDoUpgrade(t *testing.T) {
tmpDir := t.TempDir()
BundlePath = filepath.Join(tmpDir, "Ollama.app")
appContents := filepath.Join(BundlePath, "Contents")
appBackupDir = filepath.Join(tmpDir, "backup")
appContentsOld := filepath.Join(appBackupDir, "Ollama.app", "Contents")
UpdateStageDir = filepath.Join(tmpDir, "updates")
UpgradeMarkerFile = filepath.Join(tmpDir, "upgraded")
bundle := filepath.Join(UpdateStageDir, "foo", "ollama-darwin.zip")
err := os.MkdirAll(filepath.Join(appContents, "MacOS"), 0o755)
if err != nil {
t.Fatal("failed to create empty dirs")
}
err = os.MkdirAll(filepath.Join(BundlePath, "Contents", "Resources"), 0o755)
if err != nil {
t.Fatal("failed to create empty dirs")
}
err = os.MkdirAll(filepath.Dir(bundle), 0o755)
if err != nil {
t.Fatal("failed to create empty dirs")
}
// No update file, simple failure scenario
if err := DoUpgrade(false); err == nil {
t.Fatal("expected failure without download")
} else if !strings.Contains(err.Error(), "failed to lookup downloads") {
t.Fatalf("unexpected error: %s", err.Error())
}
// Start with an unreadable zip file
if err := os.WriteFile(bundle, []byte{0x4b, 0x50, 0x40, 0x03, 0x00, 0x0a, 0x00}, 0o755); err != nil {
t.Fatalf("failed to create intentionally corrupt zip file: %s", err)
}
if err := DoUpgrade(false); err == nil {
t.Fatal("expected failure with corrupt zip file")
} else if !strings.Contains(err.Error(), "unable to open upgrade bundle") {
t.Fatalf("unexpected error with corrupt zip file: %s", err)
}
// Generate valid (partial) zip file for remaining scenarios
if err := zipCreationHelper(bundle, []testPayload{
{
Name: "Ollama.app/Contents/MacOS/Ollama",
Body: []byte("would be app binary"),
},
{
Name: "Ollama.app/Contents/Resources/ollama",
Body: []byte("would be the cli"),
},
{
Name: "Ollama.app/Contents/Resources/dummy",
Body: []byte("./ollama"),
Mode: os.ModeSymlink,
},
}); err != nil {
t.Fatal(err)
}
// Permission failure on rename
if err := os.Chmod(BundlePath, 0o500); err != nil {
t.Fatal("failed to remove write permission")
}
if err := DoUpgrade(false); err == nil {
t.Fatal("expected failure with no permission to rename Contents")
} else if !strings.Contains(err.Error(), "permission problems") {
t.Fatalf("unexpected error with permission failure: %s", err)
}
if err := os.Chmod(BundlePath, 0o755); err != nil {
t.Fatal("failed to restore write permission")
}
// Prior failed upgrade
if err := os.MkdirAll(appContentsOld, 0o755); err != nil {
t.Fatal("failed to create empty dirs")
}
if err := DoUpgrade(false); err == nil {
t.Fatal("expected failure with old contents existing")
} else if !strings.Contains(err.Error(), "prior upgrade failed") {
t.Fatalf("unexpected error with old contents: %s", err)
}
if err := os.RemoveAll(appBackupDir); err != nil {
t.Fatal("failed to cleanup dir")
}
// TODO - a failure mode where we revert the backup
// Happy path
if err := DoUpgrade(false); err != nil {
t.Fatalf("unexpected error with clean setup: %s", err)
}
if _, err := os.Stat(appContentsOld); err != nil {
t.Fatalf("missing %s", appContentsOld)
}
if _, err := os.Stat(UpgradeMarkerFile); err != nil {
t.Fatalf("missing marker %s", UpgradeMarkerFile)
}
if _, err := os.Stat(filepath.Join(BundlePath, "Contents", "MacOS", "Ollama")); err != nil {
t.Fatalf("missing new App")
}
if _, err := os.Stat(filepath.Join(BundlePath, "Contents", "Resources", "ollama")); err != nil {
t.Fatalf("missing new cli")
}
// Cleanup before next attempt
if err := DoPostUpgradeCleanup(); err != nil {
t.Fatal("failed to cleanup dir")
}
err = os.MkdirAll(filepath.Dir(bundle), 0o755)
if err != nil {
t.Fatal("failed to create empty dirs")
}
// Zip file with one corrupt file within to trigger a rollback
if err := os.WriteFile(bundle, corruptZipData, 0o755); err != nil {
t.Fatalf("failed to create intentionally corrupt zip file: %s", err)
}
if err := DoUpgrade(false); err == nil {
t.Fatal("expected failure with corrupt zip file")
} else if !strings.Contains(err.Error(), "failed to open bundle file") {
t.Fatalf("unexpected error with corrupt zip file: %s", err)
}
// Make sure things were restored on partial failure
if _, err := os.Stat(appContents); err != nil {
t.Fatalf("missing %s", appContents)
}
if _, err := os.Stat(appContentsOld); err == nil {
t.Fatal("old contents still exists")
}
if _, err := os.Stat(filepath.Join(BundlePath, "Contents", "MacOS", "Ollama")); err != nil {
t.Fatalf("missing old App")
}
if _, err := os.Stat(filepath.Join(BundlePath, "Contents", "Resources", "ollama")); err != nil {
t.Fatalf("missing old cli")
}
}
func TestDoUpgradeAtStartup(t *testing.T) {
tmpDir := t.TempDir()
BundlePath = filepath.Join(tmpDir, "Ollama.app")
appBackupDir = filepath.Join(tmpDir, "backup")
UpdateStageDir = filepath.Join(tmpDir, "updates")
UpgradeMarkerFile = filepath.Join(tmpDir, "upgraded")
bundle := filepath.Join(UpdateStageDir, "foo", "ollama-darwin.zip")
if err := DoUpgradeAtStartup(); err == nil {
t.Fatal("expected failure without download")
} else if !strings.Contains(err.Error(), "failed to lookup downloads") {
t.Fatalf("unexpected error: %s", err.Error())
}
if err := os.MkdirAll(filepath.Dir(bundle), 0o755); err != nil {
t.Fatal("failed to create empty dirs")
}
if err := zipCreationHelper(bundle, []testPayload{
{
Name: "Ollama.app/Contents/MacOS/Ollama",
Body: []byte("would be app binary"),
},
{
Name: "Ollama.app/Contents/Resources/ollama",
Body: []byte("would be the cli"),
},
{
Name: "Ollama.app/Contents/Resources/dummy",
Body: []byte("./ollama"),
Mode: os.ModeSymlink,
},
}); err != nil {
t.Fatal(err)
}
if err := DoUpgradeAtStartup(); err != nil {
t.Fatalf("unexpected error with verification failure: %s", err)
}
if _, err := os.Stat(bundle); err == nil {
t.Fatalf("unverified bundle still exists %s", bundle)
}
}
func TestVerifyDownloadFailures(t *testing.T) {
tmpDir := t.TempDir()
BundlePath = filepath.Join(tmpDir, "Ollama.app")
UpdateStageDir = filepath.Join(tmpDir, "staging")
bundle := filepath.Join(UpdateStageDir, "foo", "ollama-darwin.zip")
if err := os.MkdirAll(filepath.Dir(bundle), 0o755); err != nil {
t.Fatal("failed to create empty dirs")
}
tests := []struct {
n string
in []testPayload
expected string
}{
{"breakout", []testPayload{
{
Name: "Ollama.app/",
Body: []byte{},
}, {
Name: "Ollama.app/Resources/ollama",
Body: []byte("cli payload here"),
}, {
Name: "Ollama.app/Contents/MacOS/Ollama",
Body: []byte("../../../../breakout"),
Mode: os.ModeSymlink,
},
}, "bundle contains link outside"},
{"absolute", []testPayload{{
Name: "Ollama.app/Contents/MacOS/Ollama",
Body: []byte("/etc/foo"),
Mode: os.ModeSymlink,
}}, "bundle contains absolute"},
{"missing", []testPayload{{
Name: "Ollama.app/Contents/MacOS/Ollama",
Body: []byte("../nothere"),
Mode: os.ModeSymlink,
}}, "no such file or directory"},
{"unsigned", []testPayload{{
Name: "Ollama.app/Contents/MacOS/Ollama",
Body: []byte{0xfa, 0xcf, 0xfe, 0xed, 0x00, 0x0c, 0x01, 0x00},
}}, "signature verification failed"},
}
for _, tt := range tests {
t.Run(tt.n, func(t *testing.T) {
_ = os.Remove(bundle)
if err := zipCreationHelper(bundle, tt.in); err != nil {
t.Fatal(err)
}
err := VerifyDownload()
if err == nil || !strings.Contains(err.Error(), tt.expected) {
t.Fatalf("expected \"%s\" got %s", tt.expected, err)
}
})
}
}
// One file has been corrupted to cause a checksum mismatch
var corruptZipData = []byte{0x50, 0x4b, 0x3, 0x4, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0xed, 0x7e, 0x8f, 0x59, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xb, 0x0, 0x1c, 0x0, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x2e, 0x61, 0x70, 0x70, 0x2f, 0x55, 0x54, 0x9, 0x0, 0x3, 0x6d, 0x6c, 0x5f, 0x67, 0x6e, 0x6c, 0x5f, 0x67, 0x75, 0x78, 0xb, 0x0, 0x1, 0x4, 0xf5, 0x1, 0x0, 0x0, 0x4, 0x14, 0x0, 0x0, 0x0, 0x50, 0x4b, 0x3, 0x4, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd8, 0x7e, 0x8f, 0x59, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x14, 0x0, 0x1c, 0x0, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x2e, 0x61, 0x70, 0x70, 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x55, 0x54, 0x9, 0x0, 0x3, 0x48, 0x6c, 0x5f, 0x67, 0x58, 0x6c, 0x5f, 0x67, 0x75, 0x78, 0xb, 0x0, 0x1, 0x4, 0xf5, 0x1, 0x0, 0x0, 0x4, 0x14, 0x0, 0x0, 0x0, 0x50, 0x4b, 0x3, 0x4, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe3, 0x7e, 0x8f, 0x59, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1a, 0x0, 0x1c, 0x0, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x2e, 0x61, 0x70, 0x70, 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x4d, 0x61, 0x63, 0x4f, 0x53, 0x2f, 0x55, 0x54, 0x9, 0x0, 0x3, 0x59, 0x6c, 0x5f, 0x67, 0x9f, 0x6c, 0x5f, 0x67, 0x75, 0x78, 0xb, 0x0, 0x1, 0x4, 0xf5, 0x1, 0x0, 0x0, 0x4, 0x14, 0x0, 0x0, 0x0, 0x50, 0x4b, 0x3, 0x4, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe3, 0x7e, 0x8f, 0x59, 0xe3, 0x6, 0x15, 0x70, 0x14, 0x0, 0x0, 0x0, 0x14, 0x0, 0x0, 0x0, 0x20, 0x0, 0x1c, 0x0, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x2e, 0x61, 0x70, 0x70, 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x4d, 0x61, 0x63, 0x4f, 0x53, 0x2f, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x55, 0x54, 0x9, 0x0, 0x3, 0x59, 0x6c, 0x5f, 0x67, 0x83, 0x6c, 0x5f, 0x67, 0x75, 0x78, 0xb, 0x0, 0x1, 0x4, 0xf5, 0x1, 0x0, 0x0, 0x4, 0x14, 0x0, 0x0, 0x0, 0x43, 0x4f, 0x52, 0x52, 0x55, 0x50, 0x54, 0xa, 0x50, 0x4b, 0x3, 0x4, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe9, 0x7e, 0x8f, 0x59, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1e, 0x0, 0x1c, 0x0, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x2e, 0x61, 0x70, 0x70, 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x2f, 0x55, 0x54, 0x9, 0x0, 0x3, 0x66, 0x6c, 0x5f, 0x67, 0x83, 0x6c, 0x5f, 0x67, 0x75, 0x78, 0xb, 0x0, 0x1, 0x4, 0xf5, 0x1, 0x0, 0x0, 0x4, 0x14, 0x0, 0x0, 0x0, 0x50, 0x4b, 0x3, 0x4, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe9, 0x7e, 0x8f, 0x59, 0x19, 0xa5, 0x62, 0xf7, 0x11, 0x0, 0x0, 0x0, 0x11, 0x0, 0x0, 0x0, 0x24, 0x0, 0x1c, 0x0, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x2e, 0x61, 0x70, 0x70, 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x2f, 0x6f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x55, 0x54, 0x9, 0x0, 0x3, 0x66, 0x6c, 0x5f, 0x67, 0x66, 0x6c, 0x5f, 0x67, 0x75, 0x78, 0xb, 0x0, 0x1, 0x4, 0xf5, 0x1, 0x0, 0x0, 0x4, 0x14, 0x0, 0x0, 0x0, 0x77, 0x6f, 0x75, 0x6c, 0x64, 0x20, 0x62, 0x65, 0x20, 0x74, 0x68, 0x65, 0x20, 0x63, 0x6c, 0x69, 0xa, 0x50, 0x4b, 0x1, 0x2, 0x1e, 0x3, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0xed, 0x7e, 0x8f, 0x59, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xb, 0x0, 0x18, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x0, 0xed, 0x41, 0x0, 0x0, 0x0, 0x0, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x2e, 0x61, 0x70, 0x70, 0x2f, 0x55, 0x54, 0x5, 0x0, 0x3, 0x6d, 0x6c, 0x5f, 0x67, 0x75, 0x78, 0xb, 0x0, 0x1, 0x4, 0xf5, 0x1, 0x0, 0x0, 0x4, 0x14, 0x0, 0x0, 0x0, 0x50, 0x4b, 0x1, 0x2, 0x1e, 0x3, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0xd8, 0x7e, 0x8f, 0x59, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x14, 0x0, 0x18, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x0, 0xed, 0x41, 0x45, 0x0, 0x0, 0x0, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x2e, 0x61, 0x70, 0x70, 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x55, 0x54, 0x5, 0x0, 0x3, 0x48, 0x6c, 0x5f, 0x67, 0x75, 0x78, 0xb, 0x0, 0x1, 0x4, 0xf5, 0x1, 0x0, 0x0, 0x4, 0x14, 0x0, 0x0, 0x0, 0x50, 0x4b, 0x1, 0x2, 0x1e, 0x3, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe3, 0x7e, 0x8f, 0x59, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1a, 0x0, 0x18, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x0, 0xed, 0x41, 0x93, 0x0, 0x0, 0x0, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x2e, 0x61, 0x70, 0x70, 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x4d, 0x61, 0x63, 0x4f, 0x53, 0x2f, 0x55, 0x54, 0x5, 0x0, 0x3, 0x59, 0x6c, 0x5f, 0x67, 0x75, 0x78, 0xb, 0x0, 0x1, 0x4, 0xf5, 0x1, 0x0, 0x0, 0x4, 0x14, 0x0, 0x0, 0x0, 0x50, 0x4b, 0x1, 0x2, 0x1e, 0x3, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe3, 0x7e, 0x8f, 0x59, 0xe3, 0x6, 0x15, 0x70, 0x14, 0x0, 0x0, 0x0, 0x14, 0x0, 0x0, 0x0, 0x20, 0x0, 0x18, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa4, 0x81, 0xe7, 0x0, 0x0, 0x0, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x2e, 0x61, 0x70, 0x70, 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x4d, 0x61, 0x63, 0x4f, 0x53, 0x2f, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x55, 0x54, 0x5, 0x0, 0x3, 0x59, 0x6c, 0x5f, 0x67, 0x75, 0x78, 0xb, 0x0, 0x1, 0x4, 0xf5, 0x1, 0x0, 0x0, 0x4, 0x14, 0x0, 0x0, 0x0, 0x50, 0x4b, 0x1, 0x2, 0x1e, 0x3, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe9, 0x7e, 0x8f, 0x59, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1e, 0x0, 0x18, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x0, 0xed, 0x41, 0x55, 0x1, 0x0, 0x0, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x2e, 0x61, 0x70, 0x70, 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x2f, 0x55, 0x54, 0x5, 0x0, 0x3, 0x66, 0x6c, 0x5f, 0x67, 0x75, 0x78, 0xb, 0x0, 0x1, 0x4, 0xf5, 0x1, 0x0, 0x0, 0x4, 0x14, 0x0, 0x0, 0x0, 0x50, 0x4b, 0x1, 0x2, 0x1e, 0x3, 0xa, 0x0, 0x0, 0x0, 0x0, 0x0, 0xe9, 0x7e, 0x8f, 0x59, 0x19, 0xa5, 0x62, 0xf7, 0x11, 0x0, 0x0, 0x0, 0x11, 0x0, 0x0, 0x0, 0x24, 0x0, 0x18, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xa4, 0x81, 0xad, 0x1, 0x0, 0x0, 0x4f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x2e, 0x61, 0x70, 0x70, 0x2f, 0x43, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x73, 0x2f, 0x52, 0x65, 0x73, 0x6f, 0x75, 0x72, 0x63, 0x65, 0x73, 0x2f, 0x6f, 0x6c, 0x6c, 0x61, 0x6d, 0x61, 0x55, 0x54, 0x5, 0x0, 0x3, 0x66, 0x6c, 0x5f, 0x67, 0x75, 0x78, 0xb, 0x0, 0x1, 0x4, 0xf5, 0x1, 0x0, 0x0, 0x4, 0x14, 0x0, 0x0, 0x0, 0x50, 0x4b, 0x5, 0x6, 0x0, 0x0, 0x0, 0x0, 0x6, 0x0, 0x6, 0x0, 0x3f, 0x2, 0x0, 0x0, 0x1c, 0x2, 0x0, 0x0, 0x0, 0x0}
type testPayload struct {
Name string
Body []byte
Mode fs.FileMode
}
func zipCreationHelper(filename string, files []testPayload) error {
fd, err := os.Create(filename)
if err != nil {
return err
}
w := zip.NewWriter(fd)
for _, file := range files {
fh := &zip.FileHeader{
Name: file.Name,
Flags: 0,
}
if file.Mode != 0 {
fh.SetMode(file.Mode)
}
f, err := w.CreateHeader(fh)
if err != nil {
return err
}
_, err = f.Write(file.Body)
if err != nil {
return err
}
}
return w.Close()
}
func TestAlreadyMoved(t *testing.T) {
oldPath := SystemWidePath
defer func() {
SystemWidePath = oldPath
}()
exe, err := os.Executable()
if err != nil {
t.Fatal("failed to find executable path")
}
tmpDir := t.TempDir()
testApp := filepath.Join(tmpDir, "Ollama.app")
err = os.MkdirAll(filepath.Join(testApp, "Contents", "MacOS"), 0o755)
if err != nil {
t.Fatal("failed to create Contents dir")
}
SystemWidePath = testApp
testBinary := filepath.Join(testApp, "Contents", "MacOS", "Ollama")
if err := os.Symlink(exe, testBinary); err != nil {
t.Fatalf("failed to create symlink to executable: %s", err)
}
bundle := alreadyMoved()
if bundle != testApp {
t.Fatalf("expected %s, got %s", testApp, bundle)
}
// "Keep scenario"
testApp = filepath.Join(tmpDir, "Ollama 2.app")
err = os.MkdirAll(filepath.Join(testApp, "Contents", "MacOS"), 0o755)
if err != nil {
t.Fatal("failed to create Contents dir")
}
testBinary = filepath.Join(testApp, "Contents", "MacOS", "Ollama")
if err := os.Symlink(exe, testBinary); err != nil {
t.Fatalf("failed to create symlink to executable: %s", err)
}
bundle = alreadyMoved()
if bundle != testApp {
t.Fatalf("expected %s, got %s", testApp, bundle)
}
}
//go:build windows || darwin
package updater
import (
"archive/zip"
"bytes"
"context"
"fmt"
"io"
"log/slog"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/ollama/ollama/app/store"
)
func TestIsNewReleaseAvailable(t *testing.T) {
slog.SetLogLoggerLevel(slog.LevelDebug)
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
// TODO - wire up the redirects to mimic real behavior
} else {
slog.Debug("unexpected request", "url", r.URL)
}
}))
defer server.Close()
slog.Debug("server", "url", server.URL)
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close() // Ensure database is closed
UpdateCheckURLBase = server.URL + "/update.json"
updatePresent, resp := updater.checkForUpdate(t.Context())
if !updatePresent {
t.Fatal("expected update to be available")
}
if resp.UpdateVersion != "9.9.9" {
t.Fatal("unexpected response", "url", resp.UpdateURL, "version", resp.UpdateVersion)
}
}
func TestBackgoundChecker(t *testing.T) {
UpdateStageDir = t.TempDir()
haveUpdate := false
verified := false
done := make(chan int)
cb := func(ver string) error {
haveUpdate = true
done <- 0
return nil
}
stallTimer := time.NewTimer(5 * time.Second)
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
UpdateCheckInitialDelay = 5 * time.Millisecond
UpdateCheckInterval = 5 * time.Millisecond
VerifyDownload = func() error {
verified = true
return nil
}
var server *httptest.Server
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/update.json" {
w.Write([]byte(
fmt.Sprintf(`{"version": "9.9.9", "url": "%s"}`,
server.URL+"/9.9.9/"+Installer)))
// TODO - wire up the redirects to mimic real behavior
} else if r.URL.Path == "/9.9.9/"+Installer {
buf := &bytes.Buffer{}
zw := zip.NewWriter(buf)
zw.Close()
io.Copy(w, buf)
} else {
slog.Debug("unexpected request", "url", r.URL)
}
}))
defer server.Close()
UpdateCheckURLBase = server.URL + "/update.json"
updater := &Updater{Store: &store.Store{}}
defer updater.Store.Close() // Ensure database is closed
updater.StartBackgroundUpdaterChecker(ctx, cb)
select {
case <-stallTimer.C:
t.Fatal("stalled")
case <-done:
if !haveUpdate {
t.Fatal("no update received")
}
if !verified {
t.Fatal("unverified")
}
}
}
package updater
import (
"errors"
"fmt"
"log/slog"
"os"
"os/exec"
"path"
"path/filepath"
"strings"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
var runningInstaller string
type OSVERSIONINFOEXW struct {
dwOSVersionInfoSize uint32
dwMajorVersion uint32
dwMinorVersion uint32
dwBuildNumber uint32
dwPlatformId uint32
szCSDVersion [128]uint16
wServicePackMajor uint16
wServicePackMinor uint16
wSuiteMask uint16
wProductType uint8
wReserved uint8
}
func init() {
VerifyDownload = verifyDownload
Installer = "Ollama-darwin.zip"
localAppData := os.Getenv("LOCALAPPDATA")
appDataDir := filepath.Join(localAppData, "Ollama")
// Use a distinct update staging directory from the old desktop app
// to avoid double upgrades on the transition
UpdateStageDir = filepath.Join(appDataDir, "updates_v2")
UpgradeLogFile = filepath.Join(appDataDir, "upgrade.log")
Installer = "OllamaSetup.exe"
runningInstaller = filepath.Join(appDataDir, Installer)
UpgradeMarkerFile = filepath.Join(appDataDir, "upgraded")
loadOSVersion()
}
func loadOSVersion() {
UserAgentOS = "Windows"
verInfo := OSVERSIONINFOEXW{}
verInfo.dwOSVersionInfoSize = (uint32)(unsafe.Sizeof(verInfo))
ntdll, err := windows.LoadDLL("ntdll.dll")
if err != nil {
slog.Warn("unable to find ntdll", "error", err)
return
}
defer ntdll.Release()
pRtlGetVersion, err := ntdll.FindProc("RtlGetVersion")
if err != nil {
slog.Warn("unable to locate RtlGetVersion", "error", err)
return
}
status, _, err := pRtlGetVersion.Call(uintptr(unsafe.Pointer(&verInfo)))
if status < 0x80000000 { // Success or Informational
// Note: Windows 11 reports 10.0.22000 or newer
UserAgentOS = fmt.Sprintf("Windows/%d.%d.%d", verInfo.dwMajorVersion, verInfo.dwMinorVersion, verInfo.dwBuildNumber)
} else {
slog.Warn("unable to get OS version", "error", err)
}
}
func getStagedUpdate() string {
// When transitioning from old to new app, cleanup the update from the old staging dir
// This can eventually be removed once enough time has passed since the transition
cleanupOldDownloads(filepath.Join(os.Getenv("LOCALAPPDATA"), "Ollama", "updates"))
files, err := filepath.Glob(filepath.Join(UpdateStageDir, "*", "*.exe"))
if err != nil {
slog.Debug("failed to lookup downloads", "error", err)
return ""
}
if len(files) == 0 {
return ""
} else if len(files) > 1 {
// Shouldn't happen
slog.Warn("multiple update downloads found, using first one", "bundles", files)
}
return files[0]
}
func DoUpgrade(interactive bool) error {
bundle := getStagedUpdate()
if bundle == "" {
return fmt.Errorf("failed to lookup downloads")
}
// We move the installer to ensure we don't race with multiple apps starting in quick succession
if err := os.Rename(bundle, runningInstaller); err != nil {
return fmt.Errorf("unable to rename %s -> %s : %w", bundle, runningInstaller, err)
}
slog.Info("upgrade log file " + UpgradeLogFile)
// make the upgrade show progress, but non interactive
installArgs := []string{
"/CLOSEAPPLICATIONS", // Quit the tray app if it's still running
"/LOG=" + filepath.Base(UpgradeLogFile), // Only relative seems reliable, so set pwd
"/FORCECLOSEAPPLICATIONS", // Force close the tray app - might be needed
"/SP", // Skip the "This will install... Do you wish to continue" prompt
"/NOCANCEL", // Disable the ability to cancel upgrade mid-flight to avoid partially installed upgrades
"/SILENT",
}
if !interactive {
// Add flags to make it totally silent without GUI
installArgs = append(installArgs, "/VERYSILENT", "/SUPPRESSMSGBOXES")
}
slog.Info("starting upgrade", "installer", runningInstaller, "args", installArgs)
os.Chdir(filepath.Dir(UpgradeLogFile)) //nolint:errcheck
cmd := exec.Command(runningInstaller, installArgs...)
if err := cmd.Start(); err != nil {
return fmt.Errorf("unable to start ollama app %w", err)
}
if cmd.Process != nil {
err := cmd.Process.Release()
if err != nil {
slog.Error(fmt.Sprintf("failed to release server process: %s", err))
}
} else {
// TODO - some details about why it didn't start, or is this a pedantic error case?
return errors.New("installer process did not start")
}
// If the install fails to upgrade the system, and leaves a functional
// app, this marker file will cause us to remove the staged upgrade
// bundle, which will prevent trying again until we download again.
// If this becomes looping a problem, we may need to look for failures
// in the upgrade log in DoPostUpgradeCleanup and then not download
// the same version again.
f, err := os.OpenFile(UpgradeMarkerFile, os.O_RDONLY|os.O_CREATE, 0o666)
if err != nil {
slog.Warn("unable to create marker file", "file", UpgradeMarkerFile, "error", err)
}
f.Close()
// TODO should we linger for a moment and check to make sure it's actually running by checking the pid?
slog.Info("Installer started in background, exiting")
os.Exit(0)
// Not reached
return nil
}
func DoPostUpgradeCleanup() error {
cleanupOldDownloads(UpdateStageDir)
err := os.Remove(UpgradeMarkerFile)
if err != nil {
slog.Warn("unable to clean up marker file", "marker", UpgradeMarkerFile, "error", err)
}
err = os.Remove(runningInstaller)
if err != nil {
slog.Debug("failed to remove running installer on first attempt, backgrounding...", "installer", runningInstaller, "error", err)
go func() {
for range 10 {
time.Sleep(5 * time.Second)
if err := os.Remove(runningInstaller); err == nil {
slog.Debug("installer cleaned up")
return
}
slog.Debug("failed to remove running installer on background attempt", "installer", runningInstaller, "error", err)
}
}()
}
return nil
}
func verifyDownload() error {
return nil
}
func IsUpdatePending() bool {
return getStagedUpdate() != ""
}
func DoUpgradeAtStartup() error {
return DoUpgrade(false)
}
func isInstallerRunning() bool {
return len(IsProcRunning(Installer)) > 0
}
func IsProcRunning(procName string) []uint32 {
pids := make([]uint32, 2048)
var ret uint32
if err := windows.EnumProcesses(pids, &ret); err != nil || ret == 0 {
slog.Debug("failed to check for running installers", "error", err)
return nil
}
pids = pids[:ret]
matches := []uint32{}
for _, pid := range pids {
if pid == 0 {
continue
}
hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, pid)
if err != nil {
continue
}
defer windows.CloseHandle(hProcess)
var module windows.Handle
var cbNeeded uint32
cb := (uint32)(unsafe.Sizeof(module))
if err := windows.EnumProcessModules(hProcess, &module, cb, &cbNeeded); err != nil {
continue
}
var sz uint32 = 1024 * 8
moduleName := make([]uint16, sz)
cb = uint32(len(moduleName)) * (uint32)(unsafe.Sizeof(uint16(0)))
if err := windows.GetModuleBaseName(hProcess, module, &moduleName[0], cb); err != nil && err != syscall.ERROR_INSUFFICIENT_BUFFER {
continue
}
exeFile := path.Base(strings.ToLower(syscall.UTF16ToString(moduleName)))
if strings.EqualFold(exeFile, procName) {
matches = append(matches, pid)
}
}
return matches
}
//go:build windows || darwin
package updater
import (
"log/slog"
"testing"
)
func TestIsInstallerRunning(t *testing.T) {
slog.SetLogLoggerLevel(slog.LevelDebug)
Installer = "go.exe"
if !isInstallerRunning() {
t.Fatal("not running")
}
}
//go:build windows || darwin
package version
var Version string = "0.0.0"
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
# vendor/
# Webview
This is a vendored version of the [webview_go](https://github.com/webview/webview_go) project
`WebView2.h` is a vendored version of the Microsoft's [WebView2](https://learn.microsoft.com/en-us/microsoft-edge/webview2/get-started/win32)
This source diff could not be displayed because it is too large. You can view the blob instead.
#include "webview.h"
#include <stdlib.h>
#include <stdint.h>
struct binding_context {
webview_t w;
uintptr_t index;
};
void _webviewDispatchGoCallback(void *);
void _webviewBindingGoCallback(webview_t, char *, char *, uintptr_t);
static void _webview_dispatch_cb(webview_t w, void *arg) {
_webviewDispatchGoCallback(arg);
}
static void _webview_binding_cb(const char *id, const char *req, void *arg) {
struct binding_context *ctx = (struct binding_context *) arg;
_webviewBindingGoCallback(ctx->w, (char *)id, (char *)req, ctx->index);
}
void CgoWebViewDispatch(webview_t w, uintptr_t arg) {
webview_dispatch(w, _webview_dispatch_cb, (void *)arg);
}
void CgoWebViewBind(webview_t w, const char *name, uintptr_t index) {
struct binding_context *ctx = calloc(1, sizeof(struct binding_context));
ctx->w = w;
ctx->index = index;
webview_bind(w, name, _webview_binding_cb, (void *)ctx);
}
void CgoWebViewUnbind(webview_t w, const char *name) {
webview_unbind(w, name);
}
#include "webview.h"
\ No newline at end of file
//go:build windows || darwin
/*
* MIT License
*
* Copyright (c) 2017 Serge Zaitsev
* Copyright (c) 2022 Steffen André Langnes
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
package webview
/*
#cgo CFLAGS: -I${SRCDIR}/libs/webview/include
#cgo CXXFLAGS: -I${SRCDIR}/libs/webview/include -DWEBVIEW_STATIC
#cgo darwin CXXFLAGS: -DWEBVIEW_COCOA -std=c++11
#cgo darwin LDFLAGS: -framework WebKit -ldl
#cgo windows CXXFLAGS: -DWEBVIEW_EDGE -std=c++14 -I${SRCDIR}/libs/mswebview2/include
#cgo windows LDFLAGS: -static -ladvapi32 -lole32 -lshell32 -lshlwapi -luser32 -lversion
#include "webview.h"
#include <stdlib.h>
#include <stdint.h>
void CgoWebViewDispatch(webview_t w, uintptr_t arg);
void CgoWebViewBind(webview_t w, const char *name, uintptr_t index);
void CgoWebViewUnbind(webview_t w, const char *name);
void webview_set_zoom(webview_t w, double level);
double webview_get_zoom(webview_t w);
*/
import "C"
import (
"encoding/json"
"errors"
"reflect"
"runtime"
"sync"
"unsafe"
)
func init() {
// Ensure that main.main is called from the main thread
runtime.LockOSThread()
}
// Hints are used to configure window sizing and resizing
type Hint int
const (
// Width and height are default size
HintNone = C.WEBVIEW_HINT_NONE
// Window size can not be changed by a user
HintFixed = C.WEBVIEW_HINT_FIXED
// Width and height are minimum bounds
HintMin = C.WEBVIEW_HINT_MIN
// Width and height are maximum bounds
HintMax = C.WEBVIEW_HINT_MAX
)
type WebView interface {
// Run runs the main loop until it's terminated. After this function exits -
// you must destroy the webview.
Run()
// Terminate stops the main loop. It is safe to call this function from
// a background thread.
Terminate()
// Dispatch posts a function to be executed on the main thread. You normally
// do not need to call this function, unless you want to tweak the native
// window.
Dispatch(f func())
// Destroy destroys a webview and closes the native window.
Destroy()
// Window returns a native window handle pointer. When using GTK backend the
// pointer is GtkWindow pointer, when using Cocoa backend the pointer is
// NSWindow pointer, when using Win32 backend the pointer is HWND pointer.
Window() unsafe.Pointer
// SetTitle updates the title of the native window. Must be called from the UI
// thread.
SetTitle(title string)
// SetSize updates native window size. See Hint constants.
SetSize(w int, h int, hint Hint)
// Navigate navigates webview to the given URL. URL may be a properly encoded data.
// URI. Examples:
// w.Navigate("https://github.com/webview/webview")
// w.Navigate("data:text/html,%3Ch1%3EHello%3C%2Fh1%3E")
// w.Navigate("data:text/html;base64,PGgxPkhlbGxvPC9oMT4=")
Navigate(url string)
// SetHtml sets the webview HTML directly.
// Example: w.SetHtml(w, "<h1>Hello</h1>");
SetHtml(html string)
// Init injects JavaScript code at the initialization of the new page. Every
// time the webview will open a the new page - this initialization code will
// be executed. It is guaranteed that code is executed before window.onload.
Init(js string)
// Eval evaluates arbitrary JavaScript code. Evaluation happens asynchronously,
// also the result of the expression is ignored. Use RPC bindings if you want
// to receive notifications about the results of the evaluation.
Eval(js string)
// Bind binds a callback function so that it will appear under the given name
// as a global JavaScript function. Internally it uses webview_init().
// Callback receives a request string and a user-provided argument pointer.
// Request string is a JSON array of all the arguments passed to the
// JavaScript function.
//
// f must be a function
// f must return either value and error or just error
Bind(name string, f interface{}) error
// Removes a callback that was previously set by Bind.
Unbind(name string) error
// SetZoom sets the zoom level of the webview.
// level: 1.0 is normal size, >1.0 zooms in, <1.0 zooms out.
SetZoom(level float64)
// GetZoom returns the current zoom level of the webview.
GetZoom() float64
}
type webview struct {
w C.webview_t
}
var (
m sync.Mutex
index uintptr
dispatch = map[uintptr]func(){}
bindings = map[uintptr]func(id, req string) (interface{}, error){}
)
func boolToInt(b bool) C.int {
if b {
return 1
}
return 0
}
// New calls NewWindow to create a new window and a new webview instance. If debug
// is non-zero - developer tools will be enabled (if the platform supports them).
func New(debug bool) WebView { return NewWindow(debug, nil) }
// NewWindow creates a new webview instance. If debug is non-zero - developer
// tools will be enabled (if the platform supports them). Window parameter can be
// a pointer to the native window handle. If it's non-null - then child WebView is
// embedded into the given parent window. Otherwise a new window is created.
// Depending on the platform, a GtkWindow, NSWindow or HWND pointer can be passed
// here.
func NewWindow(debug bool, window unsafe.Pointer) WebView {
w := &webview{}
w.w = C.webview_create(boolToInt(debug), window)
return w
}
func (w *webview) Destroy() {
C.webview_destroy(w.w)
}
func (w *webview) Run() {
C.webview_run(w.w)
}
func (w *webview) Terminate() {
C.webview_terminate(w.w)
}
func (w *webview) Window() unsafe.Pointer {
return C.webview_get_window(w.w)
}
func (w *webview) Navigate(url string) {
s := C.CString(url)
defer C.free(unsafe.Pointer(s))
C.webview_navigate(w.w, s)
}
func (w *webview) SetHtml(html string) {
s := C.CString(html)
defer C.free(unsafe.Pointer(s))
C.webview_set_html(w.w, s)
}
func (w *webview) SetTitle(title string) {
s := C.CString(title)
defer C.free(unsafe.Pointer(s))
C.webview_set_title(w.w, s)
}
func (w *webview) SetSize(width int, height int, hint Hint) {
C.webview_set_size(w.w, C.int(width), C.int(height), C.webview_hint_t(hint))
}
func (w *webview) Init(js string) {
s := C.CString(js)
defer C.free(unsafe.Pointer(s))
C.webview_init(w.w, s)
}
func (w *webview) Eval(js string) {
s := C.CString(js)
defer C.free(unsafe.Pointer(s))
C.webview_eval(w.w, s)
}
func (w *webview) Dispatch(f func()) {
m.Lock()
for ; dispatch[index] != nil; index++ {
}
dispatch[index] = f
m.Unlock()
C.CgoWebViewDispatch(w.w, C.uintptr_t(index))
}
//export _webviewDispatchGoCallback
func _webviewDispatchGoCallback(index unsafe.Pointer) {
m.Lock()
f := dispatch[uintptr(index)]
delete(dispatch, uintptr(index))
m.Unlock()
f()
}
//export _webviewBindingGoCallback
func _webviewBindingGoCallback(w C.webview_t, id *C.char, req *C.char, index uintptr) {
m.Lock()
f := bindings[index]
m.Unlock()
jsString := func(v interface{}) string { b, _ := json.Marshal(v); return string(b) }
status := 0
var result string
if res, err := f(C.GoString(id), C.GoString(req)); err != nil {
status = -1
result = jsString(err.Error())
} else if b, err := json.Marshal(res); err != nil {
status = -1
result = jsString(err.Error())
} else {
status = 0
result = string(b)
}
s := C.CString(result)
defer C.free(unsafe.Pointer(s))
C.webview_return(w, id, C.int(status), s)
}
func (w *webview) Bind(name string, f interface{}) error {
v := reflect.ValueOf(f)
// f must be a function
if v.Kind() != reflect.Func {
return errors.New("only functions can be bound")
}
// f must return either value and error or just error
if n := v.Type().NumOut(); n > 2 {
return errors.New("function may only return a value or a value+error")
}
binding := func(id, req string) (interface{}, error) {
raw := []json.RawMessage{}
if err := json.Unmarshal([]byte(req), &raw); err != nil {
return nil, err
}
isVariadic := v.Type().IsVariadic()
numIn := v.Type().NumIn()
if (isVariadic && len(raw) < numIn-1) || (!isVariadic && len(raw) != numIn) {
return nil, errors.New("function arguments mismatch")
}
args := []reflect.Value{}
for i := range raw {
var arg reflect.Value
if isVariadic && i >= numIn-1 {
arg = reflect.New(v.Type().In(numIn - 1).Elem())
} else {
arg = reflect.New(v.Type().In(i))
}
if err := json.Unmarshal(raw[i], arg.Interface()); err != nil {
return nil, err
}
args = append(args, arg.Elem())
}
errorType := reflect.TypeOf((*error)(nil)).Elem()
res := v.Call(args)
switch len(res) {
case 0:
// No results from the function, just return nil
return nil, nil
case 1:
// One result may be a value, or an error
if res[0].Type().Implements(errorType) {
if res[0].Interface() != nil {
return nil, res[0].Interface().(error)
}
return nil, nil
}
return res[0].Interface(), nil
case 2:
// Two results: first one is value, second is error
if !res[1].Type().Implements(errorType) {
return nil, errors.New("second return value must be an error")
}
if res[1].Interface() == nil {
return res[0].Interface(), nil
}
return res[0].Interface(), res[1].Interface().(error)
default:
return nil, errors.New("unexpected number of return values")
}
}
m.Lock()
for ; bindings[index] != nil; index++ {
}
bindings[index] = binding
m.Unlock()
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
C.CgoWebViewBind(w.w, cname, C.uintptr_t(index))
return nil
}
func (w *webview) Unbind(name string) error {
cname := C.CString(name)
defer C.free(unsafe.Pointer(cname))
C.CgoWebViewUnbind(w.w, cname)
return nil
}
func (w *webview) SetZoom(level float64) {
C.webview_set_zoom(w.w, C.double(level))
}
func (w *webview) GetZoom() float64 {
return float64(C.webview_get_zoom(w.w))
}
/*
* MIT License
*
* Copyright (c) 2017 Serge Zaitsev
* Copyright (c) 2022 Steffen André Langnes
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
/// @file webview.h
#ifndef WEBVIEW_H
#define WEBVIEW_H
/**
* Used to specify function linkage such as extern, inline, etc.
*
* When @c WEBVIEW_API is not already defined, the defaults are as follows:
*
* - @c inline when compiling C++ code.
* - @c extern when compiling C code.
*
* The following macros can be used to automatically set an appropriate
* value for @c WEBVIEW_API:
*
* - Define @c WEBVIEW_BUILD_SHARED when building a shared library.
* - Define @c WEBVIEW_SHARED when using a shared library.
* - Define @c WEBVIEW_STATIC when building or using a static library.
*/
#ifndef WEBVIEW_API
#if defined(WEBVIEW_SHARED) || defined(WEBVIEW_BUILD_SHARED)
#if defined(_WIN32) || defined(__CYGWIN__)
#if defined(WEBVIEW_BUILD_SHARED)
#define WEBVIEW_API __declspec(dllexport)
#else
#define WEBVIEW_API __declspec(dllimport)
#endif
#else
#define WEBVIEW_API __attribute__((visibility("default")))
#endif
#elif !defined(WEBVIEW_STATIC) && defined(__cplusplus)
#define WEBVIEW_API inline
#else
#define WEBVIEW_API extern
#endif
#endif
/// @name Version
/// @{
#ifndef WEBVIEW_VERSION_MAJOR
/// The current library major version.
#define WEBVIEW_VERSION_MAJOR 0
#endif
#ifndef WEBVIEW_VERSION_MINOR
/// The current library minor version.
#define WEBVIEW_VERSION_MINOR 11
#endif
#ifndef WEBVIEW_VERSION_PATCH
/// The current library patch version.
#define WEBVIEW_VERSION_PATCH 0
#endif
#ifndef WEBVIEW_VERSION_PRE_RELEASE
/// SemVer 2.0.0 pre-release labels prefixed with "-".
#define WEBVIEW_VERSION_PRE_RELEASE ""
#endif
#ifndef WEBVIEW_VERSION_BUILD_METADATA
/// SemVer 2.0.0 build metadata prefixed with "+".
#define WEBVIEW_VERSION_BUILD_METADATA ""
#endif
/// @}
/// @name Used internally
/// @{
/// Utility macro for stringifying a macro argument.
#define WEBVIEW_STRINGIFY(x) #x
/// Utility macro for stringifying the result of a macro argument expansion.
#define WEBVIEW_EXPAND_AND_STRINGIFY(x) WEBVIEW_STRINGIFY(x)
/// @}
/// @name Version
/// @{
/// SemVer 2.0.0 version number in MAJOR.MINOR.PATCH format.
#define WEBVIEW_VERSION_NUMBER \
WEBVIEW_EXPAND_AND_STRINGIFY(WEBVIEW_VERSION_MAJOR) \
"." WEBVIEW_EXPAND_AND_STRINGIFY( \
WEBVIEW_VERSION_MINOR) "." WEBVIEW_EXPAND_AND_STRINGIFY(WEBVIEW_VERSION_PATCH)
/// @}
/// Holds the elements of a MAJOR.MINOR.PATCH version number.
typedef struct {
/// Major version.
unsigned int major;
/// Minor version.
unsigned int minor;
/// Patch version.
unsigned int patch;
} webview_version_t;
/// Holds the library's version information.
typedef struct {
/// The elements of the version number.
webview_version_t version;
/// SemVer 2.0.0 version number in MAJOR.MINOR.PATCH format.
char version_number[32];
/// SemVer 2.0.0 pre-release labels prefixed with "-" if specified, otherwise
/// an empty string.
char pre_release[48];
/// SemVer 2.0.0 build metadata prefixed with "+", otherwise an empty string.
char build_metadata[48];
} webview_version_info_t;
/// Pointer to a webview instance.
typedef void *webview_t;
/// Native handle kind. The actual type depends on the backend.
typedef enum {
/// Top-level window. @c GtkWindow pointer (GTK), @c NSWindow pointer (Cocoa)
/// or @c HWND (Win32).
WEBVIEW_NATIVE_HANDLE_KIND_UI_WINDOW,
/// Browser widget. @c GtkWidget pointer (GTK), @c NSView pointer (Cocoa) or
/// @c HWND (Win32).
WEBVIEW_NATIVE_HANDLE_KIND_UI_WIDGET,
/// Browser controller. @c WebKitWebView pointer (WebKitGTK), @c WKWebView
/// pointer (Cocoa/WebKit) or @c ICoreWebView2Controller pointer
/// (Win32/WebView2).
WEBVIEW_NATIVE_HANDLE_KIND_BROWSER_CONTROLLER
} webview_native_handle_kind_t;
/// Window size hints
typedef enum {
/// Width and height are default size.
WEBVIEW_HINT_NONE,
/// Width and height are minimum bounds.
WEBVIEW_HINT_MIN,
/// Width and height are maximum bounds.
WEBVIEW_HINT_MAX,
/// Window size can not be changed by a user.
WEBVIEW_HINT_FIXED
} webview_hint_t;
#ifdef __cplusplus
extern "C" {
#endif
/**
* Creates a new webview instance.
*
* @param debug Enable developer tools if supported by the backend.
* @param window Optional native window handle, i.e. @c GtkWindow pointer
* @c NSWindow pointer (Cocoa) or @c HWND (Win32). If non-null,
* the webview widget is embedded into the given window, and the
* caller is expected to assume responsibility for the window as
* well as application lifecycle. If the window handle is null,
* a new window is created and both the window and application
* lifecycle are managed by the webview instance.
* @remark Win32: The function also accepts a pointer to @c HWND (Win32) in the
* window parameter for backward compatibility.
* @remark Win32/WebView2: @c CoInitializeEx should be called with
* @c COINIT_APARTMENTTHREADED before attempting to call this function
* with an existing window. Omitting this step may cause WebView2
* initialization to fail.
* @return @c NULL on failure. Creation can fail for various reasons such
* as when required runtime dependencies are missing or when window
* creation fails.
*/
WEBVIEW_API webview_t webview_create(int debug, void *window);
/**
* Destroys a webview instance and closes the native window.
*
* @param w The webview instance.
*/
WEBVIEW_API void webview_destroy(webview_t w);
/**
* Runs the main loop until it's terminated.
*
* @param w The webview instance.
*/
WEBVIEW_API void webview_run(webview_t w);
/**
* Stops the main loop. It is safe to call this function from another other
* background thread.
*
* @param w The webview instance.
*/
WEBVIEW_API void webview_terminate(webview_t w);
/**
* Schedules a function to be invoked on the thread with the run/event loop.
* Use this function e.g. to interact with the library or native handles.
*
* @param w The webview instance.
* @param fn The function to be invoked.
* @param arg An optional argument passed along to the callback function.
*/
WEBVIEW_API void
webview_dispatch(webview_t w, void (*fn)(webview_t w, void *arg), void *arg);
/**
* Returns the native handle of the window associated with the webview instance.
* The handle can be a @c GtkWindow pointer (GTK), @c NSWindow pointer (Cocoa)
* or @c HWND (Win32).
*
* @param w The webview instance.
* @return The handle of the native window.
*/
WEBVIEW_API void *webview_get_window(webview_t w);
/**
* Get a native handle of choice.
*
* @param w The webview instance.
* @param kind The kind of handle to retrieve.
* @return The native handle or @c NULL.
* @since 0.11
*/
WEBVIEW_API void *webview_get_native_handle(webview_t w,
webview_native_handle_kind_t kind);
/**
* Updates the title of the native window.
*
* @param w The webview instance.
* @param title The new title.
*/
WEBVIEW_API void webview_set_title(webview_t w, const char *title);
/**
* Updates the size of the native window.
*
* @param w The webview instance.
* @param width New width.
* @param height New height.
* @param hints Size hints.
*/
WEBVIEW_API void webview_set_size(webview_t w, int width, int height,
webview_hint_t hints);
/**
* Navigates webview to the given URL. URL may be a properly encoded data URI.
*
* Example:
* @code{.c}
* webview_navigate(w, "https://github.com/webview/webview");
* webview_navigate(w, "data:text/html,%3Ch1%3EHello%3C%2Fh1%3E");
* webview_navigate(w, "data:text/html;base64,PGgxPkhlbGxvPC9oMT4=");
* @endcode
*
* @param w The webview instance.
* @param url URL.
*/
WEBVIEW_API void webview_navigate(webview_t w, const char *url);
/**
* Load HTML content into the webview.
*
* Example:
* @code{.c}
* webview_set_html(w, "<h1>Hello</h1>");
* @endcode
*
* @param w The webview instance.
* @param html HTML content.
*/
WEBVIEW_API void webview_set_html(webview_t w, const char *html);
/**
* Injects JavaScript code to be executed immediately upon loading a page.
* The code will be executed before @c window.onload.
*
* @param w The webview instance.
* @param js JS content.
*/
WEBVIEW_API void webview_init(webview_t w, const char *js);
/**
* Evaluates arbitrary JavaScript code.
*
* Use bindings if you need to communicate the result of the evaluation.
*
* @param w The webview instance.
* @param js JS content.
*/
WEBVIEW_API void webview_eval(webview_t w, const char *js);
/**
* Binds a function pointer to a new global JavaScript function.
*
* Internally, JS glue code is injected to create the JS function by the
* given name. The callback function is passed a sequential request
* identifier, a request string and a user-provided argument. The request
* string is a JSON array of the arguments passed to the JS function.
*
* @param w The webview instance.
* @param name Name of the JS function.
* @param fn Callback function.
* @param arg User argument.
*/
WEBVIEW_API void webview_bind(webview_t w, const char *name,
void (*fn)(const char *seq, const char *req,
void *arg),
void *arg);
/**
* Removes a binding created with webview_bind().
*
* @param w The webview instance.
* @param name Name of the binding.
*/
WEBVIEW_API void webview_unbind(webview_t w, const char *name);
/**
* Responds to a binding call from the JS side.
*
* @param w The webview instance.
* @param seq The sequence number of the binding call. Pass along the value
* received in the binding handler (see webview_bind()).
* @param status A status of zero tells the JS side that the binding call was
* succesful; any other value indicates an error.
* @param result The result of the binding call to be returned to the JS side.
* This must either be a valid JSON value or an empty string for
* the primitive JS value @c undefined.
*/
WEBVIEW_API void webview_return(webview_t w, const char *seq, int status,
const char *result);
/**
* Get the library's version information.
*
* @since 0.10
*/
WEBVIEW_API const webview_version_info_t *webview_version(void);
/**
* Set the zoom level of the webview.
*
* @param w The webview instance.
* @param level The zoom level. 1.0 is normal size, >1.0 zooms in, <1.0 zooms out.
*/
WEBVIEW_API void webview_set_zoom(webview_t w, double level);
/**
* Get the current zoom level of the webview.
*
* @param w The webview instance.
* @return The current zoom level.
*/
WEBVIEW_API double webview_get_zoom(webview_t w);
// TODO (jmorganca): these forward declarations should be
// in a header file but due to linking issues they live here for now
typedef struct
{
char *label;
int enabled;
int separator;
} menuItem;
int menu_get_item_count();
void *menu_get_items();
void menu_handle_selection(char *item);
#ifdef __cplusplus
}
#ifndef WEBVIEW_HEADER
#if !defined(WEBVIEW_GTK) && !defined(WEBVIEW_COCOA) && !defined(WEBVIEW_EDGE)
#if defined(__APPLE__)
#define WEBVIEW_COCOA
#elif defined(__unix__)
#define WEBVIEW_GTK
#elif defined(_WIN32)
#define WEBVIEW_EDGE
#else
#error "please, specify webview backend"
#endif
#endif
#ifndef WEBVIEW_DEPRECATED
#if __cplusplus >= 201402L
#define WEBVIEW_DEPRECATED(reason) [[deprecated(reason)]]
#elif defined(_MSC_VER)
#define WEBVIEW_DEPRECATED(reason) __declspec(deprecated(reason))
#else
#define WEBVIEW_DEPRECATED(reason) __attribute__((deprecated(reason)))
#endif
#endif
#ifndef WEBVIEW_DEPRECATED_PRIVATE
#define WEBVIEW_DEPRECATED_PRIVATE \
WEBVIEW_DEPRECATED("Private API should not be used")
#endif
#include <algorithm>
#include <array>
#include <atomic>
#include <cassert>
#include <cstdint>
#include <functional>
#include <future>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include <cstring>
#if defined(_WIN32)
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#else
#include <dlfcn.h>
#endif
namespace webview {
using dispatch_fn_t = std::function<void()>;
namespace detail {
// The library's version information.
constexpr const webview_version_info_t library_version_info{
{WEBVIEW_VERSION_MAJOR, WEBVIEW_VERSION_MINOR, WEBVIEW_VERSION_PATCH},
WEBVIEW_VERSION_NUMBER,
WEBVIEW_VERSION_PRE_RELEASE,
WEBVIEW_VERSION_BUILD_METADATA};
#if defined(_WIN32)
// Converts a narrow (UTF-8-encoded) string into a wide (UTF-16-encoded) string.
inline std::wstring widen_string(const std::string &input) {
if (input.empty()) {
return std::wstring();
}
UINT cp = CP_UTF8;
DWORD flags = MB_ERR_INVALID_CHARS;
auto input_c = input.c_str();
auto input_length = static_cast<int>(input.size());
auto required_length =
MultiByteToWideChar(cp, flags, input_c, input_length, nullptr, 0);
if (required_length > 0) {
std::wstring output(static_cast<std::size_t>(required_length), L'\0');
if (MultiByteToWideChar(cp, flags, input_c, input_length, &output[0],
required_length) > 0) {
return output;
}
}
// Failed to convert string from UTF-8 to UTF-16
return std::wstring();
}
// Converts a wide (UTF-16-encoded) string into a narrow (UTF-8-encoded) string.
inline std::string narrow_string(const std::wstring &input) {
struct wc_flags {
enum TYPE : unsigned int {
// WC_ERR_INVALID_CHARS
err_invalid_chars = 0x00000080U
};
};
if (input.empty()) {
return std::string();
}
UINT cp = CP_UTF8;
DWORD flags = wc_flags::err_invalid_chars;
auto input_c = input.c_str();
auto input_length = static_cast<int>(input.size());
auto required_length = WideCharToMultiByte(cp, flags, input_c, input_length,
nullptr, 0, nullptr, nullptr);
if (required_length > 0) {
std::string output(static_cast<std::size_t>(required_length), '\0');
if (WideCharToMultiByte(cp, flags, input_c, input_length, &output[0],
required_length, nullptr, nullptr) > 0) {
return output;
}
}
// Failed to convert string from UTF-16 to UTF-8
return std::string();
}
#endif
inline int json_parse_c(const char *s, size_t sz, const char *key, size_t keysz,
const char **value, size_t *valuesz) {
enum {
JSON_STATE_VALUE,
JSON_STATE_LITERAL,
JSON_STATE_STRING,
JSON_STATE_ESCAPE,
JSON_STATE_UTF8
} state = JSON_STATE_VALUE;
const char *k = nullptr;
int index = 1;
int depth = 0;
int utf8_bytes = 0;
*value = nullptr;
*valuesz = 0;
if (key == nullptr) {
index = static_cast<decltype(index)>(keysz);
if (index < 0) {
return -1;
}
keysz = 0;
}
for (; sz > 0; s++, sz--) {
enum {
JSON_ACTION_NONE,
JSON_ACTION_START,
JSON_ACTION_END,
JSON_ACTION_START_STRUCT,
JSON_ACTION_END_STRUCT
} action = JSON_ACTION_NONE;
auto c = static_cast<unsigned char>(*s);
switch (state) {
case JSON_STATE_VALUE:
if (c == ' ' || c == '\t' || c == '\n' || c == '\r' || c == ',' ||
c == ':') {
continue;
} else if (c == '"') {
action = JSON_ACTION_START;
state = JSON_STATE_STRING;
} else if (c == '{' || c == '[') {
action = JSON_ACTION_START_STRUCT;
} else if (c == '}' || c == ']') {
action = JSON_ACTION_END_STRUCT;
} else if (c == 't' || c == 'f' || c == 'n' || c == '-' ||
(c >= '0' && c <= '9')) {
action = JSON_ACTION_START;
state = JSON_STATE_LITERAL;
} else {
return -1;
}
break;
case JSON_STATE_LITERAL:
if (c == ' ' || c == '\t' || c == '\n' || c == '\r' || c == ',' ||
c == ']' || c == '}' || c == ':') {
state = JSON_STATE_VALUE;
s--;
sz++;
action = JSON_ACTION_END;
} else if (c < 32 || c > 126) {
return -1;
} // fallthrough
case JSON_STATE_STRING:
if (c < 32 || (c > 126 && c < 192)) {
return -1;
} else if (c == '"') {
action = JSON_ACTION_END;
state = JSON_STATE_VALUE;
} else if (c == '\\') {
state = JSON_STATE_ESCAPE;
} else if (c >= 192 && c < 224) {
utf8_bytes = 1;
state = JSON_STATE_UTF8;
} else if (c >= 224 && c < 240) {
utf8_bytes = 2;
state = JSON_STATE_UTF8;
} else if (c >= 240 && c < 247) {
utf8_bytes = 3;
state = JSON_STATE_UTF8;
} else if (c >= 128 && c < 192) {
return -1;
}
break;
case JSON_STATE_ESCAPE:
if (c == '"' || c == '\\' || c == '/' || c == 'b' || c == 'f' ||
c == 'n' || c == 'r' || c == 't' || c == 'u') {
state = JSON_STATE_STRING;
} else {
return -1;
}
break;
case JSON_STATE_UTF8:
if (c < 128 || c > 191) {
return -1;
}
utf8_bytes--;
if (utf8_bytes == 0) {
state = JSON_STATE_STRING;
}
break;
default:
return -1;
}
if (action == JSON_ACTION_END_STRUCT) {
depth--;
}
if (depth == 1) {
if (action == JSON_ACTION_START || action == JSON_ACTION_START_STRUCT) {
if (index == 0) {
*value = s;
} else if (keysz > 0 && index == 1) {
k = s;
} else {
index--;
}
} else if (action == JSON_ACTION_END ||
action == JSON_ACTION_END_STRUCT) {
if (*value != nullptr && index == 0) {
*valuesz = (size_t)(s + 1 - *value);
return 0;
} else if (keysz > 0 && k != nullptr) {
if (keysz == (size_t)(s - k - 1) && memcmp(key, k + 1, keysz) == 0) {
index = 0;
} else {
index = 2;
}
k = nullptr;
}
}
}
if (action == JSON_ACTION_START_STRUCT) {
depth++;
}
}
return -1;
}
constexpr bool is_json_special_char(char c) {
return c == '"' || c == '\\' || c == '\b' || c == '\f' || c == '\n' ||
c == '\r' || c == '\t';
}
constexpr bool is_ascii_control_char(char c) { return c >= 0 && c <= 0x1f; }
inline std::string json_escape(const std::string &s, bool add_quotes = true) {
// Calculate the size of the resulting string.
// Add space for the double quotes.
size_t required_length = add_quotes ? 2 : 0;
for (auto c : s) {
if (is_json_special_char(c)) {
// '\' and a single following character
required_length += 2;
continue;
}
if (is_ascii_control_char(c)) {
// '\', 'u', 4 digits
required_length += 6;
continue;
}
++required_length;
}
// Allocate memory for resulting string only once.
std::string result;
result.reserve(required_length);
if (add_quotes) {
result += '"';
}
// Copy string while escaping characters.
for (auto c : s) {
if (is_json_special_char(c)) {
static constexpr char special_escape_table[256] =
"\0\0\0\0\0\0\0\0btn\0fr\0\0"
"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"
"\0\0\"\0\0\0\0\0\0\0\0\0\0\0\0\0"
"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"
"\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"
"\0\0\0\0\0\0\0\0\0\0\0\0\\";
result += '\\';
// NOLINTNEXTLINE(cppcoreguidelines-pro-bounds-constant-array-index)
result += special_escape_table[static_cast<unsigned char>(c)];
continue;
}
if (is_ascii_control_char(c)) {
// Escape as \u00xx
static constexpr char hex_alphabet[]{"0123456789abcdef"};
auto uc = static_cast<unsigned char>(c);
auto h = (uc >> 4) & 0x0f;
auto l = uc & 0x0f;
result += "\\u00";
// NOLINTBEGIN(cppcoreguidelines-pro-bounds-constant-array-index)
result += hex_alphabet[h];
result += hex_alphabet[l];
// NOLINTEND(cppcoreguidelines-pro-bounds-constant-array-index)
continue;
}
result += c;
}
if (add_quotes) {
result += '"';
}
// Should have calculated the exact amount of memory needed
assert(required_length == result.size());
return result;
}
inline int json_unescape(const char *s, size_t n, char *out) {
int r = 0;
if (*s++ != '"') {
return -1;
}
while (n > 2) {
char c = *s;
if (c == '\\') {
s++;
n--;
switch (*s) {
case 'b':
c = '\b';
break;
case 'f':
c = '\f';
break;
case 'n':
c = '\n';
break;
case 'r':
c = '\r';
break;
case 't':
c = '\t';
break;
case '\\':
c = '\\';
break;
case '/':
c = '/';
break;
case '\"':
c = '\"';
break;
default: // TODO: support unicode decoding
return -1;
}
}
if (out != nullptr) {
*out++ = c;
}
s++;
n--;
r++;
}
if (*s != '"') {
return -1;
}
if (out != nullptr) {
*out = '\0';
}
return r;
}
inline std::string json_parse(const std::string &s, const std::string &key,
const int index) {
const char *value;
size_t value_sz;
if (key.empty()) {
json_parse_c(s.c_str(), s.length(), nullptr, index, &value, &value_sz);
} else {
json_parse_c(s.c_str(), s.length(), key.c_str(), key.length(), &value,
&value_sz);
}
if (value != nullptr) {
if (value[0] != '"') {
return {value, value_sz};
}
int n = json_unescape(value, value_sz, nullptr);
if (n > 0) {
char *decoded = new char[n + 1];
json_unescape(value, value_sz, decoded);
std::string result(decoded, n);
delete[] decoded;
return result;
}
}
return "";
}
// Holds a symbol name and associated type for code clarity.
template <typename T> class library_symbol {
public:
using type = T;
constexpr explicit library_symbol(const char *name) : m_name(name) {}
constexpr const char *get_name() const { return m_name; }
private:
const char *m_name;
};
// Loads a native shared library and allows one to get addresses for those
// symbols.
class native_library {
public:
native_library() = default;
explicit native_library(const std::string &name)
: m_handle{load_library(name)} {}
#ifdef _WIN32
explicit native_library(const std::wstring &name)
: m_handle{load_library(name)} {}
#endif
~native_library() {
if (m_handle) {
#ifdef _WIN32
FreeLibrary(m_handle);
#else
dlclose(m_handle);
#endif
m_handle = nullptr;
}
}
native_library(const native_library &other) = delete;
native_library &operator=(const native_library &other) = delete;
native_library(native_library &&other) noexcept { *this = std::move(other); }
native_library &operator=(native_library &&other) noexcept {
if (this == &other) {
return *this;
}
m_handle = other.m_handle;
other.m_handle = nullptr;
return *this;
}
// Returns true if the library is currently loaded; otherwise false.
operator bool() const { return is_loaded(); }
// Get the address for the specified symbol or nullptr if not found.
template <typename Symbol>
typename Symbol::type get(const Symbol &symbol) const {
if (is_loaded()) {
// NOLINTBEGIN(cppcoreguidelines-pro-type-reinterpret-cast)
#ifdef _WIN32
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wcast-function-type"
#endif
return reinterpret_cast<typename Symbol::type>(
GetProcAddress(m_handle, symbol.get_name()));
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
#else
return reinterpret_cast<typename Symbol::type>(
dlsym(m_handle, symbol.get_name()));
#endif
// NOLINTEND(cppcoreguidelines-pro-type-reinterpret-cast)
}
return nullptr;
}
// Returns true if the library is currently loaded; otherwise false.
bool is_loaded() const { return !!m_handle; }
void detach() { m_handle = nullptr; }
// Returns true if the library by the given name is currently loaded; otherwise false.
static inline bool is_loaded(const std::string &name) {
#ifdef _WIN32
auto handle = GetModuleHandleW(widen_string(name).c_str());
#else
auto handle = dlopen(name.c_str(), RTLD_NOW | RTLD_NOLOAD);
if (handle) {
dlclose(handle);
}
#endif
return !!handle;
}
private:
#ifdef _WIN32
using mod_handle_t = HMODULE;
#else
using mod_handle_t = void *;
#endif
static inline mod_handle_t load_library(const std::string &name) {
#ifdef _WIN32
return load_library(widen_string(name));
#else
return dlopen(name.c_str(), RTLD_NOW);
#endif
}
#ifdef _WIN32
static inline mod_handle_t load_library(const std::wstring &name) {
return LoadLibraryW(name.c_str());
}
#endif
mod_handle_t m_handle{};
};
class engine_base {
public:
virtual ~engine_base() = default;
void navigate(const std::string &url) {
if (url.empty()) {
navigate_impl("about:blank");
return;
}
navigate_impl(url);
}
using binding_t = std::function<void(std::string, std::string, void *)>;
class binding_ctx_t {
public:
binding_ctx_t(binding_t callback, void *arg)
: callback(callback), arg(arg) {}
// This function is called upon execution of the bound JS function
binding_t callback;
// This user-supplied argument is passed to the callback
void *arg;
};
using sync_binding_t = std::function<std::string(std::string)>;
// Synchronous bind
void bind(const std::string &name, sync_binding_t fn) {
auto wrapper = [this, fn](const std::string &seq, const std::string &req,
void * /*arg*/) { resolve(seq, 0, fn(req)); };
bind(name, wrapper, nullptr);
}
// Asynchronous bind
void bind(const std::string &name, binding_t fn, void *arg) {
// NOLINTNEXTLINE(readability-container-contains): contains() requires C++20
if (bindings.count(name) > 0) {
return;
}
bindings.emplace(name, binding_ctx_t(fn, arg));
auto js = "(function() { var name = '" + name + "';" + R""(
var RPC = window._rpc = (window._rpc || {nextSeq: 1});
window[name] = function() {
var seq = RPC.nextSeq++;
var promise = new Promise(function(resolve, reject) {
RPC[seq] = {
resolve: resolve,
reject: reject,
};
});
window.external.invoke(JSON.stringify({
id: seq,
method: name,
params: Array.prototype.slice.call(arguments),
}));
return promise;
}
})())"";
init(js);
eval(js);
}
void unbind(const std::string &name) {
auto found = bindings.find(name);
if (found != bindings.end()) {
auto js = "delete window['" + name + "'];";
init(js);
eval(js);
bindings.erase(found);
}
}
void resolve(const std::string &seq, int status, const std::string &result) {
// NOLINTNEXTLINE(modernize-avoid-bind): Lambda with move requires C++14
dispatch(std::bind(
[seq, status, this](std::string escaped_result) {
std::string js;
js += "(function(){var seq = \"";
js += seq;
js += "\";\n";
js += "var status = ";
js += std::to_string(status);
js += ";\n";
js += "var result = ";
js += escaped_result;
js += ";\
var promise = window._rpc[seq];\
delete window._rpc[seq];\
if (result !== undefined) {\
try {\
result = JSON.parse(result);\
} catch {\
promise.reject(new Error(\"Failed to parse binding result as JSON\"));\
return;\
}\
}\
if (status === 0) {\
promise.resolve(result);\
} else {\
promise.reject(result);\
}\
})()";
eval(js);
},
result.empty() ? "undefined" : json_escape(result)));
}
void *window() { return window_impl(); }
void *widget() { return widget_impl(); }
void *browser_controller() { return browser_controller_impl(); };
void run() { run_impl(); }
void terminate() { terminate_impl(); }
void dispatch(std::function<void()> f) { dispatch_impl(f); }
void set_title(const std::string &title) { set_title_impl(title); }
void set_size(int width, int height, webview_hint_t hints) {
set_size_impl(width, height, hints);
}
void set_html(const std::string &html) { set_html_impl(html); }
void init(const std::string &js) { init_impl(js); }
void eval(const std::string &js) { eval_impl(js); }
void set_zoom(double level) {
const double minZoom = 0.9; // 90%
const double maxZoom = 1.2; // 120%
level = std::max(minZoom, std::min(maxZoom, level));
set_zoom_impl(level);
}
double get_zoom() { return get_zoom_impl(); }
protected:
virtual void navigate_impl(const std::string &url) = 0;
virtual void *window_impl() = 0;
virtual void *widget_impl() = 0;
virtual void *browser_controller_impl() = 0;
virtual void run_impl() = 0;
virtual void terminate_impl() = 0;
virtual void dispatch_impl(std::function<void()> f) = 0;
virtual void set_title_impl(const std::string &title) = 0;
virtual void set_size_impl(int width, int height, webview_hint_t hints) = 0;
virtual void set_html_impl(const std::string &html) = 0;
virtual void init_impl(const std::string &js) = 0;
virtual void eval_impl(const std::string &js) = 0;
virtual void set_zoom_impl(double level) = 0;
virtual double get_zoom_impl() = 0;
virtual void on_message(const std::string &msg) {
auto seq = json_parse(msg, "id", 0);
auto name = json_parse(msg, "method", 0);
auto args = json_parse(msg, "params", 0);
auto found = bindings.find(name);
if (found == bindings.end()) {
return;
}
const auto &context = found->second;
context.callback(seq, args, context.arg);
}
virtual void on_window_created() { inc_window_count(); }
virtual void on_window_destroyed(bool skip_termination = false) {
if (dec_window_count() <= 0) {
if (!skip_termination) {
terminate();
}
}
}
private:
static std::atomic_uint &window_ref_count() {
static std::atomic_uint ref_count{0};
return ref_count;
}
static unsigned int inc_window_count() { return ++window_ref_count(); }
static unsigned int dec_window_count() {
auto &count = window_ref_count();
if (count > 0) {
return --count;
}
return 0;
}
std::map<std::string, binding_ctx_t> bindings;
};
} // namespace detail
WEBVIEW_DEPRECATED_PRIVATE
inline int json_parse_c(const char *s, size_t sz, const char *key, size_t keysz,
const char **value, size_t *valuesz) {
return detail::json_parse_c(s, sz, key, keysz, value, valuesz);
}
WEBVIEW_DEPRECATED_PRIVATE
inline std::string json_escape(const std::string &s) {
return detail::json_escape(s);
}
WEBVIEW_DEPRECATED_PRIVATE
inline int json_unescape(const char *s, size_t n, char *out) {
return detail::json_unescape(s, n, out);
}
WEBVIEW_DEPRECATED_PRIVATE
inline std::string json_parse(const std::string &s, const std::string &key,
const int index) {
return detail::json_parse(s, key, index);
}
} // namespace webview
#if defined(WEBVIEW_GTK)
//
// ====================================================================
//
// This implementation uses webkit2gtk backend. It requires gtk+3.0 and
// webkit2gtk-4.0 libraries. Proper compiler flags can be retrieved via:
//
// pkg-config --cflags --libs gtk+-3.0 webkit2gtk-4.0
//
// ====================================================================
//
#include <cstdlib>
#include <JavaScriptCore/JavaScript.h>
#include <gtk/gtk.h>
#include <webkit2/webkit2.h>
#ifdef GDK_WINDOWING_X11
#include <gdk/gdkx.h>
#endif
#include <fcntl.h>
#include <sys/stat.h>
namespace webview {
namespace detail {
// Namespace containing workaround for WebKit 2.42 when using NVIDIA GPU
// driver.
// See WebKit bug: https://bugs.webkit.org/show_bug.cgi?id=261874
// Please remove all of the code in this namespace when it's no longer needed.
namespace webkit_dmabuf {
// Get environment variable. Not thread-safe.
static inline std::string get_env(const std::string &name) {
auto *value = std::getenv(name.c_str());
if (value) {
return {value};
}
return {};
}
// Set environment variable. Not thread-safe.
static inline void set_env(const std::string &name, const std::string &value) {
::setenv(name.c_str(), value.c_str(), 1);
}
// Checks whether the NVIDIA GPU driver is used based on whether the kernel
// module is loaded.
static inline bool is_using_nvidia_driver() {
struct ::stat buffer {};
if (::stat("/sys/module/nvidia", &buffer) != 0) {
return false;
}
return S_ISDIR(buffer.st_mode);
}
// Checks whether the windowing system is Wayland.
static inline bool is_wayland_display() {
if (!get_env("WAYLAND_DISPLAY").empty()) {
return true;
}
if (get_env("XDG_SESSION_TYPE") == "wayland") {
return true;
}
if (get_env("DESKTOP_SESSION").find("wayland") != std::string::npos) {
return true;
}
return false;
}
// Checks whether the GDK X11 backend is used.
// See: https://docs.gtk.org/gdk3/class.DisplayManager.html
static inline bool is_gdk_x11_backend() {
#ifdef GDK_WINDOWING_X11
auto *manager = gdk_display_manager_get();
auto *display = gdk_display_manager_get_default_display(manager);
return GDK_IS_X11_DISPLAY(display); // NOLINT(misc-const-correctness)
#else
return false;
#endif
}
// Checks whether WebKit is affected by bug when using DMA-BUF renderer.
// Returns true if all of the following conditions are met:
// - WebKit version is >= 2.42 (please narrow this down when there's a fix).
// - Environment variables are empty or not set:
// - WEBKIT_DISABLE_DMABUF_RENDERER
// - Windowing system is not Wayland.
// - GDK backend is X11.
// - NVIDIA GPU driver is used.
static inline bool is_webkit_dmabuf_bugged() {
auto wk_major = webkit_get_major_version();
auto wk_minor = webkit_get_minor_version();
// TODO: Narrow down affected WebKit version when there's a fixed version
auto is_affected_wk_version = wk_major == 2 && wk_minor >= 42;
if (!is_affected_wk_version) {
return false;
}
if (!get_env("WEBKIT_DISABLE_DMABUF_RENDERER").empty()) {
return false;
}
if (is_wayland_display()) {
return false;
}
if (!is_gdk_x11_backend()) {
return false;
}
if (!is_using_nvidia_driver()) {
return false;
}
return true;
}
// Applies workaround for WebKit DMA-BUF bug if needed.
// See WebKit bug: https://bugs.webkit.org/show_bug.cgi?id=261874
static inline void apply_webkit_dmabuf_workaround() {
if (!is_webkit_dmabuf_bugged()) {
return;
}
set_env("WEBKIT_DISABLE_DMABUF_RENDERER", "1");
}
} // namespace webkit_dmabuf
namespace webkit_symbols {
using webkit_web_view_evaluate_javascript_t =
void (*)(WebKitWebView *, const char *, gssize, const char *, const char *,
GCancellable *, GAsyncReadyCallback, gpointer);
using webkit_web_view_run_javascript_t = void (*)(WebKitWebView *,
const gchar *, GCancellable *,
GAsyncReadyCallback,
gpointer);
constexpr auto webkit_web_view_evaluate_javascript =
library_symbol<webkit_web_view_evaluate_javascript_t>(
"webkit_web_view_evaluate_javascript");
constexpr auto webkit_web_view_run_javascript =
library_symbol<webkit_web_view_run_javascript_t>(
"webkit_web_view_run_javascript");
} // namespace webkit_symbols
class gtk_webkit_engine : public engine_base {
public:
gtk_webkit_engine(bool debug, void *window)
: m_owns_window{!window}, m_window(static_cast<GtkWidget *>(window)) {
if (m_owns_window) {
if (gtk_init_check(nullptr, nullptr) == FALSE) {
return;
}
m_window = gtk_window_new(GTK_WINDOW_TOPLEVEL);
on_window_created();
g_signal_connect(G_OBJECT(m_window), "destroy",
G_CALLBACK(+[](GtkWidget *, gpointer arg) {
auto *w = static_cast<gtk_webkit_engine *>(arg);
// Widget destroyed along with window.
w->m_webview = nullptr;
w->m_window = nullptr;
w->on_window_destroyed();
}),
this);
}
webkit_dmabuf::apply_webkit_dmabuf_workaround();
// Initialize webview widget
m_webview = webkit_web_view_new();
WebKitUserContentManager *manager =
webkit_web_view_get_user_content_manager(WEBKIT_WEB_VIEW(m_webview));
g_signal_connect(manager, "script-message-received::external",
G_CALLBACK(+[](WebKitUserContentManager *,
WebKitJavascriptResult *r, gpointer arg) {
auto *w = static_cast<gtk_webkit_engine *>(arg);
char *s = get_string_from_js_result(r);
w->on_message(s);
g_free(s);
}),
this);
webkit_user_content_manager_register_script_message_handler(manager,
"external");
init("window.external={invoke:function(s){window.webkit.messageHandlers."
"external.postMessage(s);}}");
gtk_container_add(GTK_CONTAINER(m_window), GTK_WIDGET(m_webview));
gtk_widget_show(GTK_WIDGET(m_webview));
WebKitSettings *settings =
webkit_web_view_get_settings(WEBKIT_WEB_VIEW(m_webview));
webkit_settings_set_javascript_can_access_clipboard(settings, true);
if (debug) {
webkit_settings_set_enable_write_console_messages_to_stdout(settings,
true);
webkit_settings_set_enable_developer_extras(settings, true);
}
if (m_owns_window) {
gtk_widget_grab_focus(GTK_WIDGET(m_webview));
gtk_widget_show_all(m_window);
}
}
gtk_webkit_engine(const gtk_webkit_engine &) = delete;
gtk_webkit_engine &operator=(const gtk_webkit_engine &) = delete;
gtk_webkit_engine(gtk_webkit_engine &&) = delete;
gtk_webkit_engine &operator=(gtk_webkit_engine &&) = delete;
virtual ~gtk_webkit_engine() {
if (m_webview) {
gtk_widget_destroy(GTK_WIDGET(m_webview));
m_webview = nullptr;
}
if (m_window) {
if (m_owns_window) {
// Disconnect handlers to avoid callbacks invoked during destruction.
g_signal_handlers_disconnect_by_data(GTK_WINDOW(m_window), this);
gtk_window_close(GTK_WINDOW(m_window));
on_window_destroyed(true);
}
m_window = nullptr;
}
if (m_owns_window) {
// Needed for the window to close immediately.
deplete_run_loop_event_queue();
}
}
void *window_impl() override { return (void *)m_window; }
void *widget_impl() override { return (void *)m_webview; }
void *browser_controller_impl() override { return (void *)m_webview; };
void run_impl() override { gtk_main(); }
void terminate_impl() override {
dispatch_impl([] { gtk_main_quit(); });
}
void dispatch_impl(std::function<void()> f) override {
g_idle_add_full(G_PRIORITY_HIGH_IDLE, (GSourceFunc)([](void *f) -> int {
(*static_cast<dispatch_fn_t *>(f))();
return G_SOURCE_REMOVE;
}),
new std::function<void()>(f),
[](void *f) { delete static_cast<dispatch_fn_t *>(f); });
}
void set_title_impl(const std::string &title) override {
gtk_window_set_title(GTK_WINDOW(m_window), title.c_str());
}
void set_size_impl(int width, int height, webview_hint_t hints) override {
gtk_window_set_resizable(GTK_WINDOW(m_window), hints != WEBVIEW_HINT_FIXED);
if (hints == WEBVIEW_HINT_NONE) {
gtk_window_resize(GTK_WINDOW(m_window), width, height);
} else if (hints == WEBVIEW_HINT_FIXED) {
gtk_widget_set_size_request(m_window, width, height);
} else {
GdkGeometry g;
g.min_width = g.max_width = width;
g.min_height = g.max_height = height;
GdkWindowHints h =
(hints == WEBVIEW_HINT_MIN ? GDK_HINT_MIN_SIZE : GDK_HINT_MAX_SIZE);
// This defines either MIN_SIZE, or MAX_SIZE, but not both:
gtk_window_set_geometry_hints(GTK_WINDOW(m_window), nullptr, &g, h);
}
}
void navigate_impl(const std::string &url) override {
webkit_web_view_load_uri(WEBKIT_WEB_VIEW(m_webview), url.c_str());
}
void set_html_impl(const std::string &html) override {
webkit_web_view_load_html(WEBKIT_WEB_VIEW(m_webview), html.c_str(),
nullptr);
}
void init_impl(const std::string &js) override {
WebKitUserContentManager *manager =
webkit_web_view_get_user_content_manager(WEBKIT_WEB_VIEW(m_webview));
webkit_user_content_manager_add_script(
manager,
webkit_user_script_new(js.c_str(), WEBKIT_USER_CONTENT_INJECT_TOP_FRAME,
WEBKIT_USER_SCRIPT_INJECT_AT_DOCUMENT_START,
nullptr, nullptr));
}
void eval_impl(const std::string &js) override {
auto &lib = get_webkit_library();
auto wkmajor = webkit_get_major_version();
auto wkminor = webkit_get_minor_version();
if ((wkmajor == 2 && wkminor >= 40) || wkmajor > 2) {
if (auto fn =
lib.get(webkit_symbols::webkit_web_view_evaluate_javascript)) {
fn(WEBKIT_WEB_VIEW(m_webview), js.c_str(),
static_cast<gssize>(js.size()), nullptr, nullptr, nullptr, nullptr,
nullptr);
}
} else if (auto fn =
lib.get(webkit_symbols::webkit_web_view_run_javascript)) {
fn(WEBKIT_WEB_VIEW(m_webview), js.c_str(), nullptr, nullptr, nullptr);
}
}
void set_zoom_impl(double level) override {
webkit_web_view_set_zoom_level(WEBKIT_WEB_VIEW(m_webview), level);
}
double get_zoom_impl() override {
return webkit_web_view_get_zoom_level(WEBKIT_WEB_VIEW(m_webview));
}
private:
static char *get_string_from_js_result(WebKitJavascriptResult *r) {
char *s;
#if (WEBKIT_MAJOR_VERSION == 2 && WEBKIT_MINOR_VERSION >= 22) || \
WEBKIT_MAJOR_VERSION > 2
JSCValue *value = webkit_javascript_result_get_js_value(r);
s = jsc_value_to_string(value);
#else
JSGlobalContextRef ctx = webkit_javascript_result_get_global_context(r);
JSValueRef value = webkit_javascript_result_get_value(r);
JSStringRef js = JSValueToStringCopy(ctx, value, nullptr);
size_t n = JSStringGetMaximumUTF8CStringSize(js);
s = g_new(char, n);
JSStringGetUTF8CString(js, s, n);
JSStringRelease(js);
#endif
return s;
}
static const native_library &get_webkit_library() {
static const native_library non_loaded_lib;
static native_library loaded_lib;
if (loaded_lib.is_loaded()) {
return loaded_lib;
}
constexpr std::array<const char *, 2> lib_names{"libwebkit2gtk-4.1.so",
"libwebkit2gtk-4.0.so"};
auto found =
std::find_if(lib_names.begin(), lib_names.end(), [](const char *name) {
return native_library::is_loaded(name);
});
if (found == lib_names.end()) {
return non_loaded_lib;
}
loaded_lib = native_library(*found);
auto loaded = loaded_lib.is_loaded();
if (!loaded) {
return non_loaded_lib;
}
return loaded_lib;
}
// Blocks while depleting the run loop of events.
void deplete_run_loop_event_queue() {
bool done{};
dispatch([&] { done = true; });
while (!done) {
gtk_main_iteration();
}
}
bool m_owns_window{};
GtkWidget *m_window{};
GtkWidget *m_webview{};
};
} // namespace detail
using browser_engine = detail::gtk_webkit_engine;
} // namespace webview
#elif defined(WEBVIEW_COCOA)
//
// ====================================================================
//
// This implementation uses Cocoa WKWebView backend on macOS. It is
// written using ObjC runtime and uses WKWebView class as a browser runtime.
// You should pass "-framework Webkit" flag to the compiler.
//
// ====================================================================
//
#include <CoreGraphics/CoreGraphics.h>
#include <objc/NSObjCRuntime.h>
#include <objc/objc-runtime.h>
namespace webview {
namespace detail {
namespace objc {
// A convenient template function for unconditionally casting the specified
// C-like function into a function that can be called with the given return
// type and arguments. Caller takes full responsibility for ensuring that
// the function call is valid. It is assumed that the function will not
// throw exceptions.
template <typename Result, typename Callable, typename... Args>
Result invoke(Callable callable, Args... args) noexcept {
return reinterpret_cast<Result (*)(Args...)>(callable)(args...);
}
// Calls objc_msgSend.
template <typename Result, typename... Args>
Result msg_send(Args... args) noexcept {
return invoke<Result>(objc_msgSend, args...);
}
// Wrapper around NSAutoreleasePool that drains the pool on destruction.
class autoreleasepool {
public:
autoreleasepool()
: m_pool(msg_send<id>(objc_getClass("NSAutoreleasePool"),
sel_registerName("new"))) {}
~autoreleasepool() {
if (m_pool) {
msg_send<void>(m_pool, sel_registerName("drain"));
}
}
autoreleasepool(const autoreleasepool &) = delete;
autoreleasepool &operator=(const autoreleasepool &) = delete;
autoreleasepool(autoreleasepool &&) = delete;
autoreleasepool &operator=(autoreleasepool &&) = delete;
private:
id m_pool{};
};
inline id autoreleased(id object) {
msg_send<void>(object, sel_registerName("autorelease"));
return object;
}
} // namespace objc
enum NSBackingStoreType : NSUInteger { NSBackingStoreBuffered = 2 };
enum NSWindowStyleMask : NSUInteger {
NSWindowStyleMaskTitled = 1,
NSWindowStyleMaskClosable = 2,
NSWindowStyleMaskMiniaturizable = 4,
NSWindowStyleMaskResizable = 8
};
enum NSApplicationActivationPolicy : NSInteger {
NSApplicationActivationPolicyRegular = 0
};
enum WKUserScriptInjectionTime : NSInteger {
WKUserScriptInjectionTimeAtDocumentStart = 0
};
enum NSModalResponse : NSInteger { NSModalResponseOK = 1 };
// Convenient conversion of string literals.
inline id operator"" _cls(const char *s, std::size_t) {
return (id)objc_getClass(s);
}
inline SEL operator"" _sel(const char *s, std::size_t) {
return sel_registerName(s);
}
inline id operator"" _str(const char *s, std::size_t) {
return objc::msg_send<id>("NSString"_cls, "stringWithUTF8String:"_sel, s);
}
class cocoa_wkwebview_engine : public engine_base {
public:
cocoa_wkwebview_engine(bool debug, void *window)
: m_debug{debug}, m_window{static_cast<id>(window)}, m_owns_window{
!window} {
auto app = get_shared_application();
// See comments related to application lifecycle in create_app_delegate().
if (!m_owns_window) {
set_up_window();
} else {
// Only set the app delegate if it hasn't already been set.
auto delegate = objc::msg_send<id>(app, "delegate"_sel);
if (delegate) {
set_up_window();
} else {
m_app_delegate = create_app_delegate();
objc_setAssociatedObject(m_app_delegate, "webview", (id)this,
OBJC_ASSOCIATION_ASSIGN);
objc::msg_send<void>(app, "setDelegate:"_sel, m_app_delegate);
// Start the main run loop so that the app delegate gets the
// NSApplicationDidFinishLaunchingNotification notification after the run
// loop has started in order to perform further initialization.
// We need to return from this constructor so this run loop is only
// temporary.
// Skip the main loop if this isn't the first instance of this class
// because the launch event is only sent once. Instead, proceed to
// create a window.
if (get_and_set_is_first_instance()) {
objc::msg_send<void>(app, "run"_sel);
} else {
set_up_window();
}
}
}
}
cocoa_wkwebview_engine(const cocoa_wkwebview_engine &) = delete;
cocoa_wkwebview_engine &operator=(const cocoa_wkwebview_engine &) = delete;
cocoa_wkwebview_engine(cocoa_wkwebview_engine &&) = delete;
cocoa_wkwebview_engine &operator=(cocoa_wkwebview_engine &&) = delete;
virtual ~cocoa_wkwebview_engine() {
objc::autoreleasepool arp;
if (m_window) {
if (m_webview) {
if (m_webview == objc::msg_send<id>(m_window, "contentView"_sel)) {
objc::msg_send<void>(m_window, "setContentView:"_sel, nullptr);
}
objc::msg_send<void>(m_webview, "release"_sel);
m_webview = nullptr;
}
if (m_owns_window) {
// Replace delegate to avoid callbacks and other bad things during
// destruction.
objc::msg_send<void>(m_window, "setDelegate:"_sel, nullptr);
objc::msg_send<void>(m_window, "close"_sel);
on_window_destroyed(true);
}
m_window = nullptr;
}
if (m_window_delegate) {
objc::msg_send<void>(m_window_delegate, "release"_sel);
m_window_delegate = nullptr;
}
if (m_app_delegate) {
auto app = get_shared_application();
objc::msg_send<void>(app, "setDelegate:"_sel, nullptr);
// Make sure to release the delegate we created.
objc::msg_send<void>(m_app_delegate, "release"_sel);
m_app_delegate = nullptr;
}
if (m_owns_window) {
// Needed for the window to close immediately.
deplete_run_loop_event_queue();
}
// TODO: Figure out why m_manager is still alive after the autoreleasepool
// has been drained.
}
void *window_impl() override { return (void *)m_window; }
void *widget_impl() override { return (void *)m_webview; }
void *browser_controller_impl() override { return (void *)m_webview; };
void terminate_impl() override { stop_run_loop(); }
void run_impl() override {
auto app = get_shared_application();
objc::msg_send<void>(app, "run"_sel);
}
void dispatch_impl(std::function<void()> f) override {
dispatch_async_f(dispatch_get_main_queue(), new dispatch_fn_t(f),
(dispatch_function_t)([](void *arg) {
auto f = static_cast<dispatch_fn_t *>(arg);
(*f)();
delete f;
}));
}
void set_title_impl(const std::string &title) override {
objc::autoreleasepool arp;
objc::msg_send<void>(m_window, "setTitle:"_sel,
objc::msg_send<id>("NSString"_cls,
"stringWithUTF8String:"_sel,
title.c_str()));
}
void set_size_impl(int width, int height, webview_hint_t hints) override {
objc::autoreleasepool arp;
auto style = static_cast<NSWindowStyleMask>(
NSWindowStyleMaskTitled | NSWindowStyleMaskClosable |
NSWindowStyleMaskMiniaturizable);
if (hints != WEBVIEW_HINT_FIXED) {
style =
static_cast<NSWindowStyleMask>(style | NSWindowStyleMaskResizable);
}
objc::msg_send<void>(m_window, "setStyleMask:"_sel, style);
if (hints == WEBVIEW_HINT_MIN) {
objc::msg_send<void>(m_window, "setContentMinSize:"_sel,
CGSizeMake(width, height));
} else if (hints == WEBVIEW_HINT_MAX) {
objc::msg_send<void>(m_window, "setContentMaxSize:"_sel,
CGSizeMake(width, height));
} else {
objc::msg_send<void>(m_window, "setFrame:display:animate:"_sel,
CGRectMake(0, 0, width, height), YES, NO);
}
objc::msg_send<void>(m_window, "center"_sel);
}
void navigate_impl(const std::string &url) override {
objc::autoreleasepool arp;
auto nsurl = objc::msg_send<id>(
"NSURL"_cls, "URLWithString:"_sel,
objc::msg_send<id>("NSString"_cls, "stringWithUTF8String:"_sel,
url.c_str()));
objc::msg_send<void>(
m_webview, "loadRequest:"_sel,
objc::msg_send<id>("NSURLRequest"_cls, "requestWithURL:"_sel, nsurl));
}
void set_html_impl(const std::string &html) override {
objc::autoreleasepool arp;
objc::msg_send<void>(m_webview, "loadHTMLString:baseURL:"_sel,
objc::msg_send<id>("NSString"_cls,
"stringWithUTF8String:"_sel,
html.c_str()),
nullptr);
}
void init_impl(const std::string &js) override {
objc::autoreleasepool arp;
auto script = objc::autoreleased(objc::msg_send<id>(
objc::msg_send<id>("WKUserScript"_cls, "alloc"_sel),
"initWithSource:injectionTime:forMainFrameOnly:"_sel,
objc::msg_send<id>("NSString"_cls, "stringWithUTF8String:"_sel,
js.c_str()),
WKUserScriptInjectionTimeAtDocumentStart, YES));
objc::msg_send<void>(m_manager, "addUserScript:"_sel, script);
}
void eval_impl(const std::string &js) override {
objc::autoreleasepool arp;
objc::msg_send<void>(m_webview, "evaluateJavaScript:completionHandler:"_sel,
objc::msg_send<id>("NSString"_cls,
"stringWithUTF8String:"_sel,
js.c_str()),
nullptr);
}
void set_zoom_impl(double level) override {
objc::autoreleasepool arp;
objc::msg_send<void>(m_webview, "setPageZoom:"_sel, level);
}
double get_zoom_impl() override {
objc::autoreleasepool arp;
return objc::msg_send<double>(m_webview, "pageZoom"_sel);
}
private:
id create_app_delegate() {
objc::autoreleasepool arp;
constexpr auto class_name = "WebviewAppDelegate";
// Avoid crash due to registering same class twice
auto cls = objc_lookUpClass(class_name);
if (!cls) {
// Note: Avoid registering the class name "AppDelegate" as it is the
// default name in projects created with Xcode, and using the same name
// causes objc_registerClassPair to crash.
cls = objc_allocateClassPair((Class) "NSResponder"_cls, class_name, 0);
class_addProtocol(cls, objc_getProtocol("NSTouchBarProvider"));
class_addMethod(cls,
"applicationShouldTerminateAfterLastWindowClosed:"_sel,
(IMP)(+[](id, SEL, id) -> BOOL { return NO; }), "c@:@");
class_addMethod(cls, "applicationDidFinishLaunching:"_sel,
(IMP)(+[](id self, SEL, id notification) {
auto app =
objc::msg_send<id>(notification, "object"_sel);
auto w = get_associated_webview(self);
w->on_application_did_finish_launching(self, app);
}),
"v@:@");
objc_registerClassPair(cls);
}
return objc::msg_send<id>((id)cls, "new"_sel);
}
id create_script_message_handler() {
objc::autoreleasepool arp;
constexpr auto class_name = "WebviewWKScriptMessageHandler";
// Avoid crash due to registering same class twice
auto cls = objc_lookUpClass(class_name);
if (!cls) {
cls = objc_allocateClassPair((Class) "NSResponder"_cls, class_name, 0);
class_addProtocol(cls, objc_getProtocol("WKScriptMessageHandler"));
class_addMethod(
cls, "userContentController:didReceiveScriptMessage:"_sel,
(IMP)(+[](id self, SEL, id, id msg) {
auto w = get_associated_webview(self);
w->on_message(objc::msg_send<const char *>(
objc::msg_send<id>(msg, "body"_sel), "UTF8String"_sel));
}),
"v@:@@");
objc_registerClassPair(cls);
}
auto instance = objc::msg_send<id>((id)cls, "new"_sel);
objc_setAssociatedObject(instance, "webview", (id)this,
OBJC_ASSOCIATION_ASSIGN);
return instance;
}
static id create_webkit_ui_delegate() {
objc::autoreleasepool arp;
constexpr auto class_name = "WebviewWKUIDelegate";
// Avoid crash due to registering same class twice
auto cls = objc_lookUpClass(class_name);
if (!cls) {
cls = objc_allocateClassPair((Class) "NSObject"_cls, class_name, 0);
class_addProtocol(cls, objc_getProtocol("WKUIDelegate"));
class_addMethod(
cls,
"webView:runOpenPanelWithParameters:initiatedByFrame:completionHandler:"_sel,
(IMP)(+[](id, SEL, id, id parameters, id, id completion_handler) {
auto allows_multiple_selection =
objc::msg_send<BOOL>(parameters, "allowsMultipleSelection"_sel);
auto allows_directories =
objc::msg_send<BOOL>(parameters, "allowsDirectories"_sel);
// Show a panel for selecting files.
auto panel = objc::msg_send<id>("NSOpenPanel"_cls, "openPanel"_sel);
objc::msg_send<void>(panel, "setCanChooseFiles:"_sel, YES);
objc::msg_send<void>(panel, "setCanChooseDirectories:"_sel,
allows_directories);
objc::msg_send<void>(panel, "setAllowsMultipleSelection:"_sel,
allows_multiple_selection);
auto modal_response =
objc::msg_send<NSModalResponse>(panel, "runModal"_sel);
// Get the URLs for the selected files. If the modal was canceled
// then we pass null to the completion handler to signify
// cancellation.
id urls = modal_response == NSModalResponseOK
? objc::msg_send<id>(panel, "URLs"_sel)
: nullptr;
// Invoke the completion handler block.
auto sig = objc::msg_send<id>(
"NSMethodSignature"_cls, "signatureWithObjCTypes:"_sel, "v@?@");
auto invocation = objc::msg_send<id>(
"NSInvocation"_cls, "invocationWithMethodSignature:"_sel, sig);
objc::msg_send<void>(invocation, "setTarget:"_sel,
completion_handler);
objc::msg_send<void>(invocation, "setArgument:atIndex:"_sel, &urls,
1);
objc::msg_send<void>(invocation, "invoke"_sel);
}),
"v@:@@@@");
objc_registerClassPair(cls);
}
return objc::msg_send<id>((id)cls, "new"_sel);
}
static id create_window_delegate() {
objc::autoreleasepool arp;
constexpr auto class_name = "WebviewNSWindowDelegate";
// Avoid crash due to registering same class twice
auto cls = objc_lookUpClass(class_name);
if (!cls) {
cls = objc_allocateClassPair((Class) "NSObject"_cls, class_name, 0);
class_addProtocol(cls, objc_getProtocol("NSWindowDelegate"));
class_addMethod(cls, "windowWillClose:"_sel,
(IMP)(+[](id self, SEL, id notification) {
auto window =
objc::msg_send<id>(notification, "object"_sel);
auto w = get_associated_webview(self);
w->on_window_will_close(self, window);
}),
"v@:@");
objc_registerClassPair(cls);
}
return objc::msg_send<id>((id)cls, "new"_sel);
}
static id get_shared_application() {
return objc::msg_send<id>("NSApplication"_cls, "sharedApplication"_sel);
}
static cocoa_wkwebview_engine *get_associated_webview(id object) {
auto w =
(cocoa_wkwebview_engine *)objc_getAssociatedObject(object, "webview");
assert(w);
return w;
}
static id get_main_bundle() noexcept {
return objc::msg_send<id>("NSBundle"_cls, "mainBundle"_sel);
}
static bool is_app_bundled() noexcept {
auto bundle = get_main_bundle();
if (!bundle) {
return false;
}
auto bundle_path = objc::msg_send<id>(bundle, "bundlePath"_sel);
auto bundled =
objc::msg_send<BOOL>(bundle_path, "hasSuffix:"_sel, ".app"_str);
return !!bundled;
}
void on_application_did_finish_launching(id /*delegate*/, id app) {
// See comments related to application lifecycle in create_app_delegate().
if (m_owns_window) {
// Stop the main run loop so that we can return
// from the constructor.
stop_run_loop();
}
// Activate the app if it is not bundled.
// Bundled apps launched from Finder are activated automatically but
// otherwise not. Activating the app even when it has been launched from
// Finder does not seem to be harmful but calling this function is rarely
// needed as proper activation is normally taken care of for us.
// Bundled apps have a default activation policy of
// NSApplicationActivationPolicyRegular while non-bundled apps have a
// default activation policy of NSApplicationActivationPolicyProhibited.
if (!is_app_bundled()) {
// "setActivationPolicy:" must be invoked before
// "activateIgnoringOtherApps:" for activation to work.
objc::msg_send<void>(app, "setActivationPolicy:"_sel,
NSApplicationActivationPolicyRegular);
// Activate the app regardless of other active apps.
// This can be obtrusive so we only do it when necessary.
objc::msg_send<void>(app, "activateIgnoringOtherApps:"_sel, YES);
}
set_up_window();
}
void on_window_will_close(id /*delegate*/, id /*window*/) {
// Widget destroyed along with window.
m_webview = nullptr;
m_window = nullptr;
dispatch([this] { on_window_destroyed(); });
}
void set_up_window() {
objc::autoreleasepool arp;
// Main window
if (m_owns_window) {
m_window = objc::msg_send<id>("NSWindow"_cls, "alloc"_sel);
auto style = NSWindowStyleMaskTitled;
m_window = objc::msg_send<id>(
m_window, "initWithContentRect:styleMask:backing:defer:"_sel,
CGRectMake(0, 0, 0, 0), style, NSBackingStoreBuffered, NO);
m_window_delegate = create_window_delegate();
objc_setAssociatedObject(m_window_delegate, "webview", (id)this,
OBJC_ASSOCIATION_ASSIGN);
objc::msg_send<void>(m_window, "setDelegate:"_sel, m_window_delegate);
on_window_created();
}
set_up_web_view();
objc::msg_send<void>(m_window, "setContentView:"_sel, m_webview);
if (m_owns_window) {
// objc::msg_send<void>(m_window, "makeKeyAndOrderFront:"_sel, nullptr);
}
}
void set_up_web_view() {
objc::autoreleasepool arp;
auto config = objc::autoreleased(
objc::msg_send<id>("WKWebViewConfiguration"_cls, "new"_sel));
m_manager = objc::msg_send<id>(config, "userContentController"_sel);
m_webview = objc::msg_send<id>("WKWebView"_cls, "alloc"_sel);
auto preferences = objc::msg_send<id>(config, "preferences"_sel);
auto yes_value =
objc::msg_send<id>("NSNumber"_cls, "numberWithBool:"_sel, YES);
if (m_debug) {
// Equivalent Obj-C:
// [[config preferences] setValue:@YES forKey:@"developerExtrasEnabled"];
objc::msg_send<id>(preferences, "setValue:forKey:"_sel, yes_value,
"developerExtrasEnabled"_str);
}
// Equivalent Obj-C:
// [[config preferences] setValue:@YES forKey:@"fullScreenEnabled"];
objc::msg_send<id>(preferences, "setValue:forKey:"_sel, yes_value,
"fullScreenEnabled"_str);
// Equivalent Obj-C:
// [[config preferences] setValue:@YES forKey:@"javaScriptCanAccessClipboard"];
objc::msg_send<id>(preferences, "setValue:forKey:"_sel, yes_value,
"javaScriptCanAccessClipboard"_str);
// Equivalent Obj-C:
// [[config preferences] setValue:@YES forKey:@"DOMPasteAllowed"];
objc::msg_send<id>(preferences, "setValue:forKey:"_sel, yes_value,
"DOMPasteAllowed"_str);
auto ui_delegate = objc::autoreleased(create_webkit_ui_delegate());
objc::msg_send<void>(m_webview, "initWithFrame:configuration:"_sel,
CGRectMake(0, 0, 0, 0), config);
objc::msg_send<void>(m_webview, "setUIDelegate:"_sel, ui_delegate);
if (m_debug) {
// Explicitly make WKWebView inspectable via Safari on OS versions that
// disable the feature by default (macOS 13.3 and later) and support
// enabling it. According to Apple, the behavior on older OS versions is
// for content to always be inspectable in "debug builds".
// Testing shows that this is true for macOS 12.6 but somehow not 10.15.
// https://webkit.org/blog/13936/enabling-the-inspection-of-web-content-in-apps/
#if defined(__has_builtin)
#if __has_builtin(__builtin_available)
if (__builtin_available(macOS 13.3, iOS 16.4, tvOS 16.4, *)) {
objc::msg_send<void>(
m_webview, "setInspectable:"_sel,
objc::msg_send<id>("NSNumber"_cls, "numberWithBool:"_sel, YES));
}
#else
#error __builtin_available not supported by compiler
#endif
#else
#error __has_builtin not supported by compiler
#endif
}
auto script_message_handler =
objc::autoreleased(create_script_message_handler());
objc::msg_send<void>(m_manager, "addScriptMessageHandler:name:"_sel,
script_message_handler, "external"_str);
init(R""(
window.external = {
invoke: function(s) {
window.webkit.messageHandlers.external.postMessage(s);
},
};
)"");
}
void stop_run_loop() {
objc::autoreleasepool arp;
auto app = get_shared_application();
// Request the run loop to stop. This doesn't immediately stop the loop.
objc::msg_send<void>(app, "stop:"_sel, nullptr);
// The run loop will stop after processing an NSEvent.
// Event type: NSEventTypeApplicationDefined (macOS 10.12+),
// NSApplicationDefined (macOS 10.0–10.12)
int type = 15;
auto event = objc::msg_send<id>(
"NSEvent"_cls,
"otherEventWithType:location:modifierFlags:timestamp:windowNumber:context:subtype:data1:data2:"_sel,
type, CGPointMake(0, 0), 0, 0, 0, nullptr, 0, 0, 0);
objc::msg_send<void>(app, "postEvent:atStart:"_sel, event, YES);
}
static bool get_and_set_is_first_instance() noexcept {
static std::atomic_bool first{true};
bool temp = first;
if (temp) {
first = false;
}
return temp;
}
// Blocks while depleting the run loop of events.
void deplete_run_loop_event_queue() {
objc::autoreleasepool arp;
auto app = get_shared_application();
bool done{};
dispatch([&] { done = true; });
auto mask = NSUIntegerMax; // NSEventMaskAny
// NSDefaultRunLoopMode
auto mode = objc::msg_send<id>("NSString"_cls, "stringWithUTF8String:"_sel,
"kCFRunLoopDefaultMode");
while (!done) {
objc::autoreleasepool arp;
auto event = objc::msg_send<id>(
app, "nextEventMatchingMask:untilDate:inMode:dequeue:"_sel, mask,
nullptr, mode, YES);
if (event) {
objc::msg_send<void>(app, "sendEvent:"_sel, event);
}
}
}
bool m_debug{};
id m_app_delegate{};
id m_window_delegate{};
id m_window{};
id m_webview{};
id m_manager{};
bool m_owns_window{};
};
} // namespace detail
using browser_engine = detail::cocoa_wkwebview_engine;
} // namespace webview
#elif defined(WEBVIEW_EDGE)
//
// ====================================================================
//
// This implementation uses Win32 API to create a native window. It
// uses Edge/Chromium webview2 backend as a browser engine.
//
// ====================================================================
//
#define WIN32_LEAN_AND_MEAN
#include <shlobj.h>
#include <shlwapi.h>
#include <stdlib.h>
#include <windows.h>
#include <shellapi.h>
#include <wrl.h>
#include "WebView2.h"
#ifdef _MSC_VER
#pragma comment(lib, "advapi32.lib")
#pragma comment(lib, "ole32.lib")
#pragma comment(lib, "shell32.lib")
#pragma comment(lib, "shlwapi.lib")
#pragma comment(lib, "user32.lib")
#pragma comment(lib, "version.lib")
#endif
namespace webview {
namespace detail {
using msg_cb_t = std::function<void(const std::string)>;
// Parses a version string with 1-4 integral components, e.g. "1.2.3.4".
// Missing or invalid components default to 0, and excess components are ignored.
template <typename T>
std::array<unsigned int, 4>
parse_version(const std::basic_string<T> &version) noexcept {
auto parse_component = [](auto sb, auto se) -> unsigned int {
try {
auto n = std::stol(std::basic_string<T>(sb, se));
return n < 0 ? 0 : n;
} catch (std::exception &) {
return 0;
}
};
auto end = version.end();
auto sb = version.begin(); // subrange begin
auto se = sb; // subrange end
unsigned int ci = 0; // component index
std::array<unsigned int, 4> components{};
while (sb != end && se != end && ci < components.size()) {
if (*se == static_cast<T>('.')) {
components[ci++] = parse_component(sb, se);
sb = ++se;
continue;
}
++se;
}
if (sb < se && ci < components.size()) {
components[ci] = parse_component(sb, se);
}
return components;
}
template <typename T, std::size_t Length>
auto parse_version(const T (&version)[Length]) noexcept {
return parse_version(std::basic_string<T>(version, Length));
}
std::wstring get_file_version_string(const std::wstring &file_path) noexcept {
DWORD dummy_handle; // Unused
DWORD info_buffer_length =
GetFileVersionInfoSizeW(file_path.c_str(), &dummy_handle);
if (info_buffer_length == 0) {
return std::wstring();
}
std::vector<char> info_buffer;
info_buffer.reserve(info_buffer_length);
if (!GetFileVersionInfoW(file_path.c_str(), 0, info_buffer_length,
info_buffer.data())) {
return std::wstring();
}
auto sub_block = L"\\StringFileInfo\\040904B0\\ProductVersion";
LPWSTR version = nullptr;
unsigned int version_length = 0;
if (!VerQueryValueW(info_buffer.data(), sub_block,
reinterpret_cast<LPVOID *>(&version), &version_length)) {
return std::wstring();
}
if (!version || version_length == 0) {
return std::wstring();
}
return std::wstring(version, version_length);
}
// A wrapper around COM library initialization. Calls CoInitializeEx in the
// constructor and CoUninitialize in the destructor.
class com_init_wrapper {
public:
com_init_wrapper() = default;
com_init_wrapper(DWORD dwCoInit) {
// We can safely continue as long as COM was either successfully
// initialized or already initialized.
// RPC_E_CHANGED_MODE means that CoInitializeEx was already called with
// a different concurrency model.
switch (CoInitializeEx(nullptr, dwCoInit)) {
case S_OK:
case S_FALSE:
m_initialized = true;
break;
}
}
~com_init_wrapper() {
if (m_initialized) {
CoUninitialize();
m_initialized = false;
}
}
com_init_wrapper(const com_init_wrapper &other) = delete;
com_init_wrapper &operator=(const com_init_wrapper &other) = delete;
com_init_wrapper(com_init_wrapper &&other) { *this = std::move(other); }
com_init_wrapper &operator=(com_init_wrapper &&other) {
if (this == &other) {
return *this;
}
m_initialized = std::exchange(other.m_initialized, false);
return *this;
}
bool is_initialized() const { return m_initialized; }
private:
bool m_initialized = false;
};
namespace ntdll_symbols {
using RtlGetVersion_t =
unsigned int /*NTSTATUS*/ (WINAPI *)(RTL_OSVERSIONINFOW *);
constexpr auto RtlGetVersion = library_symbol<RtlGetVersion_t>("RtlGetVersion");
} // namespace ntdll_symbols
namespace user32_symbols {
using DPI_AWARENESS_CONTEXT = HANDLE;
using SetProcessDpiAwarenessContext_t = BOOL(WINAPI *)(DPI_AWARENESS_CONTEXT);
using SetProcessDPIAware_t = BOOL(WINAPI *)();
using GetDpiForWindow_t = UINT(WINAPI *)(HWND);
using EnableNonClientDpiScaling_t = BOOL(WINAPI *)(HWND);
using AdjustWindowRectExForDpi_t = BOOL(WINAPI *)(LPRECT, DWORD, BOOL, DWORD,
UINT);
using GetWindowDpiAwarenessContext_t = DPI_AWARENESS_CONTEXT(WINAPI *)(HWND);
using AreDpiAwarenessContextsEqual_t = BOOL(WINAPI *)(DPI_AWARENESS_CONTEXT,
DPI_AWARENESS_CONTEXT);
// Use intptr_t as the underlying type because we need to
// reinterpret_cast<DPI_AWARENESS_CONTEXT> which is a pointer.
// Available since Windows 10, version 1607
enum class dpi_awareness : intptr_t {
per_monitor_v2_aware = -4, // Available since Windows 10, version 1703
per_monitor_aware = -3
};
constexpr auto SetProcessDpiAwarenessContext =
library_symbol<SetProcessDpiAwarenessContext_t>(
"SetProcessDpiAwarenessContext");
constexpr auto SetProcessDPIAware =
library_symbol<SetProcessDPIAware_t>("SetProcessDPIAware");
constexpr auto GetDpiForWindow =
library_symbol<GetDpiForWindow_t>("GetDpiForWindow");
constexpr auto EnableNonClientDpiScaling =
library_symbol<EnableNonClientDpiScaling_t>("EnableNonClientDpiScaling");
constexpr auto AdjustWindowRectExForDpi =
library_symbol<AdjustWindowRectExForDpi_t>("AdjustWindowRectExForDpi");
constexpr auto GetWindowDpiAwarenessContext =
library_symbol<GetWindowDpiAwarenessContext_t>(
"GetWindowDpiAwarenessContext");
constexpr auto AreDpiAwarenessContextsEqual =
library_symbol<AreDpiAwarenessContextsEqual_t>(
"AreDpiAwarenessContextsEqual");
} // namespace user32_symbols
// MARGINS structure for DwmExtendFrameIntoClientArea
typedef struct _MARGINS {
int cxLeftWidth;
int cxRightWidth;
int cyTopHeight;
int cyBottomHeight;
} MARGINS;
namespace dwmapi_symbols {
typedef enum {
// This undocumented value is used instead of DWMWA_USE_IMMERSIVE_DARK_MODE
// on Windows 10 older than build 19041 (2004/20H1).
DWMWA_USE_IMMERSIVE_DARK_MODE_BEFORE_V10_0_19041 = 19,
// Documented as being supported since Windows 11 build 22000 (21H2) but it
// works since Windows 10 build 19041 (2004/20H1).
DWMWA_USE_IMMERSIVE_DARK_MODE = 20
} DWMWINDOWATTRIBUTE;
using DwmSetWindowAttribute_t = HRESULT(WINAPI *)(HWND, DWORD, LPCVOID, DWORD);
using DwmExtendFrameIntoClientArea_t = HRESULT(WINAPI *)(HWND, const MARGINS *);
constexpr auto DwmSetWindowAttribute =
library_symbol<DwmSetWindowAttribute_t>("DwmSetWindowAttribute");
constexpr auto DwmExtendFrameIntoClientArea =
library_symbol<DwmExtendFrameIntoClientArea_t>("DwmExtendFrameIntoClientArea");
} // namespace dwmapi_symbols
namespace shcore_symbols {
typedef enum { PROCESS_PER_MONITOR_DPI_AWARE = 2 } PROCESS_DPI_AWARENESS;
using SetProcessDpiAwareness_t = HRESULT(WINAPI *)(PROCESS_DPI_AWARENESS);
constexpr auto SetProcessDpiAwareness =
library_symbol<SetProcessDpiAwareness_t>("SetProcessDpiAwareness");
} // namespace shcore_symbols
class reg_key {
public:
explicit reg_key(HKEY root_key, const wchar_t *sub_key, DWORD options,
REGSAM sam_desired) {
HKEY handle;
auto status =
RegOpenKeyExW(root_key, sub_key, options, sam_desired, &handle);
if (status == ERROR_SUCCESS) {
m_handle = handle;
}
}
explicit reg_key(HKEY root_key, const std::wstring &sub_key, DWORD options,
REGSAM sam_desired)
: reg_key(root_key, sub_key.c_str(), options, sam_desired) {}
virtual ~reg_key() {
if (m_handle) {
RegCloseKey(m_handle);
m_handle = nullptr;
}
}
reg_key(const reg_key &other) = delete;
reg_key &operator=(const reg_key &other) = delete;
reg_key(reg_key &&other) = delete;
reg_key &operator=(reg_key &&other) = delete;
bool is_open() const { return !!m_handle; }
bool get_handle() const { return m_handle; }
template <typename Container>
void query_bytes(const wchar_t *name, Container &result) const {
DWORD buf_length = 0;
// Get the size of the data in bytes.
auto status = RegQueryValueExW(m_handle, name, nullptr, nullptr, nullptr,
&buf_length);
if (status != ERROR_SUCCESS && status != ERROR_MORE_DATA) {
result.resize(0);
return;
}
// Read the data.
result.resize(buf_length / sizeof(typename Container::value_type));
auto *buf = reinterpret_cast<LPBYTE>(&result[0]);
status =
RegQueryValueExW(m_handle, name, nullptr, nullptr, buf, &buf_length);
if (status != ERROR_SUCCESS) {
result.resize(0);
return;
}
}
std::wstring query_string(const wchar_t *name) const {
std::wstring result;
query_bytes(name, result);
// Remove trailing null-characters.
for (std::size_t length = result.size(); length > 0; --length) {
if (result[length - 1] != 0) {
result.resize(length);
break;
}
}
return result;
}
unsigned int query_uint(const wchar_t *name,
unsigned int default_value) const {
std::vector<char> data;
query_bytes(name, data);
if (data.size() < sizeof(DWORD)) {
return default_value;
}
return static_cast<unsigned int>(*reinterpret_cast<DWORD *>(data.data()));
}
private:
HKEY m_handle = nullptr;
};
// Compare the specified version against the OS version.
// Returns less than 0 if the OS version is less.
// Returns 0 if the versions are equal.
// Returns greater than 0 if the specified version is greater.
inline int compare_os_version(unsigned int major, unsigned int minor,
unsigned int build) {
// Use RtlGetVersion both to bypass potential issues related to
// VerifyVersionInfo and manifests, and because both GetVersion and
// GetVersionEx are deprecated.
auto ntdll = native_library(L"ntdll.dll");
if (auto fn = ntdll.get(ntdll_symbols::RtlGetVersion)) {
RTL_OSVERSIONINFOW vi{};
vi.dwOSVersionInfoSize = sizeof(vi);
if (fn(&vi) != 0) {
return false;
}
if (vi.dwMajorVersion == major) {
if (vi.dwMinorVersion == minor) {
return static_cast<int>(vi.dwBuildNumber) - static_cast<int>(build);
}
return static_cast<int>(vi.dwMinorVersion) - static_cast<int>(minor);
}
return static_cast<int>(vi.dwMajorVersion) - static_cast<int>(major);
}
return false;
}
inline bool is_per_monitor_v2_awareness_available() {
// Windows 10, version 1703
return compare_os_version(10, 0, 15063) >= 0;
}
inline bool enable_dpi_awareness() {
auto user32 = native_library(L"user32.dll");
if (auto fn = user32.get(user32_symbols::SetProcessDpiAwarenessContext)) {
auto dpi_awareness =
reinterpret_cast<user32_symbols::DPI_AWARENESS_CONTEXT>(
is_per_monitor_v2_awareness_available()
? user32_symbols::dpi_awareness::per_monitor_v2_aware
: user32_symbols::dpi_awareness::per_monitor_aware);
if (fn(dpi_awareness)) {
return true;
}
return GetLastError() == ERROR_ACCESS_DENIED;
}
if (auto shcore = native_library(L"shcore.dll")) {
if (auto fn = shcore.get(shcore_symbols::SetProcessDpiAwareness)) {
auto result = fn(shcore_symbols::PROCESS_PER_MONITOR_DPI_AWARE);
return result == S_OK || result == E_ACCESSDENIED;
}
}
if (auto fn = user32.get(user32_symbols::SetProcessDPIAware)) {
return !!fn();
}
return true;
}
inline bool enable_non_client_dpi_scaling_if_needed(HWND window) {
auto user32 = native_library(L"user32.dll");
auto get_ctx_fn = user32.get(user32_symbols::GetWindowDpiAwarenessContext);
if (!get_ctx_fn) {
return true;
}
auto awareness = get_ctx_fn(window);
if (!awareness) {
return false;
}
auto ctx_equal_fn = user32.get(user32_symbols::AreDpiAwarenessContextsEqual);
if (!ctx_equal_fn) {
return true;
}
// EnableNonClientDpiScaling is only needed with per monitor v1 awareness.
auto per_monitor = reinterpret_cast<user32_symbols::DPI_AWARENESS_CONTEXT>(
user32_symbols::dpi_awareness::per_monitor_aware);
if (!ctx_equal_fn(awareness, per_monitor)) {
return true;
}
auto enable_fn = user32.get(user32_symbols::EnableNonClientDpiScaling);
if (!enable_fn) {
return true;
}
return !!enable_fn(window);
}
constexpr int get_default_window_dpi() {
constexpr const int default_dpi = 96; // USER_DEFAULT_SCREEN_DPI
return default_dpi;
}
inline int get_window_dpi(HWND window) {
auto user32 = native_library(L"user32.dll");
if (auto fn = user32.get(user32_symbols::GetDpiForWindow)) {
auto dpi = static_cast<int>(fn(window));
return dpi;
}
return get_default_window_dpi();
}
constexpr int scale_value_for_dpi(int value, int from_dpi, int to_dpi) {
return (value * to_dpi) / from_dpi;
}
constexpr SIZE scale_size(int width, int height, int from_dpi, int to_dpi) {
auto scaled_width = scale_value_for_dpi(width, from_dpi, to_dpi);
auto scaled_height = scale_value_for_dpi(height, from_dpi, to_dpi);
return {scaled_width, scaled_height};
}
inline SIZE make_window_frame_size(HWND window, int width, int height,
int dpi) {
auto style = GetWindowLong(window, GWL_STYLE);
RECT r{0, 0, width, height};
auto user32 = native_library(L"user32.dll");
if (auto fn = user32.get(user32_symbols::AdjustWindowRectExForDpi)) {
fn(&r, style, FALSE, 0, static_cast<UINT>(dpi));
} else {
AdjustWindowRect(&r, style, 0);
}
auto frame_width = r.right - r.left;
auto frame_height = r.bottom - r.top;
return {frame_width, frame_height};
}
inline bool is_dark_theme_enabled() {
constexpr auto *sub_key =
L"SOFTWARE\\Microsoft\\Windows\\CurrentVersion\\Themes\\Personalize";
reg_key key(HKEY_CURRENT_USER, sub_key, 0, KEY_READ);
if (!key.is_open()) {
// Default is light theme
return false;
}
return key.query_uint(L"AppsUseLightTheme", 1) == 0;
}
inline void apply_window_theme(HWND window) {
auto dark_theme_enabled = is_dark_theme_enabled();
// Use "immersive dark mode" on systems that support it.
// Changes the color of the window's title bar (light or dark).
BOOL use_dark_mode{dark_theme_enabled ? TRUE : FALSE};
static native_library dwmapi{L"dwmapi.dll"};
if (auto fn = dwmapi.get(dwmapi_symbols::DwmSetWindowAttribute)) {
// Try the modern, documented attribute before the older, undocumented one.
if (fn(window, dwmapi_symbols::DWMWA_USE_IMMERSIVE_DARK_MODE,
&use_dark_mode, sizeof(use_dark_mode)) != S_OK) {
fn(window,
dwmapi_symbols::DWMWA_USE_IMMERSIVE_DARK_MODE_BEFORE_V10_0_19041,
&use_dark_mode, sizeof(use_dark_mode));
}
}
// adds dark/light colors to the titlebar instead of transparent
if (auto fn = dwmapi.get(dwmapi_symbols::DwmExtendFrameIntoClientArea)) {
MARGINS margins = { 0, 0, 0, 32 }; // No extension to preserve button alignment
fn(window, &margins);
}
}
// Enable built-in WebView2Loader implementation by default.
#ifndef WEBVIEW_MSWEBVIEW2_BUILTIN_IMPL
#define WEBVIEW_MSWEBVIEW2_BUILTIN_IMPL 1
#endif
// Link WebView2Loader.dll explicitly by default only if the built-in
// implementation is enabled.
#ifndef WEBVIEW_MSWEBVIEW2_EXPLICIT_LINK
#define WEBVIEW_MSWEBVIEW2_EXPLICIT_LINK WEBVIEW_MSWEBVIEW2_BUILTIN_IMPL
#endif
// Explicit linking of WebView2Loader.dll should be used along with
// the built-in implementation.
#if WEBVIEW_MSWEBVIEW2_BUILTIN_IMPL == 1 && \
WEBVIEW_MSWEBVIEW2_EXPLICIT_LINK != 1
#undef WEBVIEW_MSWEBVIEW2_EXPLICIT_LINK
#error Please set WEBVIEW_MSWEBVIEW2_EXPLICIT_LINK=1.
#endif
#if WEBVIEW_MSWEBVIEW2_BUILTIN_IMPL == 1
// Gets the last component of a Windows native file path.
// For example, if the path is "C:\a\b" then the result is "b".
template <typename T>
std::basic_string<T>
get_last_native_path_component(const std::basic_string<T> &path) {
auto pos = path.find_last_of(static_cast<T>('\\'));
if (pos != std::basic_string<T>::npos) {
return path.substr(pos + 1);
}
return std::basic_string<T>();
}
#endif /* WEBVIEW_MSWEBVIEW2_BUILTIN_IMPL */
template <typename T> struct cast_info_t {
using type = T;
IID iid;
};
namespace mswebview2 {
static constexpr IID IID_ICoreWebView2ScriptDialogOpeningEventHandler{
0xef381bf9,
0xafa8,
0x4e37,
{0x91, 0xc4, 0x8a, 0xc4, 0x85, 0x24, 0xbd, 0xfb}};
static constexpr IID
IID_ICoreWebView2CreateCoreWebView2ControllerCompletedHandler{
0x6C4819F3,
0xC9B7,
0x4260,
{0x81, 0x27, 0xC9, 0xF5, 0xBD, 0xE7, 0xF6, 0x8C}};
static constexpr IID
IID_ICoreWebView2CreateCoreWebView2EnvironmentCompletedHandler{
0x4E8A3389,
0xC9D8,
0x4BD2,
{0xB6, 0xB5, 0x12, 0x4F, 0xEE, 0x6C, 0xC1, 0x4D}};
static constexpr IID IID_ICoreWebView2PermissionRequestedEventHandler{
0x15E1C6A3,
0xC72A,
0x4DF3,
{0x91, 0xD7, 0xD0, 0x97, 0xFB, 0xEC, 0x6B, 0xFD}};
static constexpr IID IID_ICoreWebView2WebMessageReceivedEventHandler{
0x57213F19,
0x00E6,
0x49FA,
{0x8E, 0x07, 0x89, 0x8E, 0xA0, 0x1E, 0xCB, 0xD2}};
static constexpr IID IID_ICoreWebView2NewWindowRequestedEventHandler{
0xD4C185FE,
0xC81C,
0x481B,
{0xB9, 0x4A, 0x11, 0xD1, 0xF9, 0x64, 0xEA, 0x5C}};
static constexpr IID IID_ICoreWebView2NavigationStartingEventHandler{
0x9ADBE429,
0xF36D,
0x432B,
{0x9D, 0xDC, 0xCE, 0xC5, 0x21, 0xB0, 0x03, 0x49}};
static constexpr IID IID_ICoreWebView2ContextMenuRequestedEventHandler{
0x04d3fe1d,
0xab87,
0x42fb,
{0xa8,0x98,0xda,0x24,0x1d,0x35,0xb6,0x3c}};
#if WEBVIEW_MSWEBVIEW2_BUILTIN_IMPL == 1
enum class webview2_runtime_type { installed = 0, embedded = 1 };
namespace webview2_symbols {
using CreateWebViewEnvironmentWithOptionsInternal_t =
HRESULT(STDMETHODCALLTYPE *)(
bool, webview2_runtime_type, PCWSTR, IUnknown *,
ICoreWebView2CreateCoreWebView2EnvironmentCompletedHandler *);
using DllCanUnloadNow_t = HRESULT(STDMETHODCALLTYPE *)();
static constexpr auto CreateWebViewEnvironmentWithOptionsInternal =
library_symbol<CreateWebViewEnvironmentWithOptionsInternal_t>(
"CreateWebViewEnvironmentWithOptionsInternal");
static constexpr auto DllCanUnloadNow =
library_symbol<DllCanUnloadNow_t>("DllCanUnloadNow");
} // namespace webview2_symbols
#endif /* WEBVIEW_MSWEBVIEW2_BUILTIN_IMPL */
#if WEBVIEW_MSWEBVIEW2_EXPLICIT_LINK == 1
namespace webview2_symbols {
using CreateCoreWebView2EnvironmentWithOptions_t = HRESULT(STDMETHODCALLTYPE *)(
PCWSTR, PCWSTR, ICoreWebView2EnvironmentOptions *,
ICoreWebView2CreateCoreWebView2EnvironmentCompletedHandler *);
using GetAvailableCoreWebView2BrowserVersionString_t =
HRESULT(STDMETHODCALLTYPE *)(PCWSTR, LPWSTR *);
static constexpr auto CreateCoreWebView2EnvironmentWithOptions =
library_symbol<CreateCoreWebView2EnvironmentWithOptions_t>(
"CreateCoreWebView2EnvironmentWithOptions");
static constexpr auto GetAvailableCoreWebView2BrowserVersionString =
library_symbol<GetAvailableCoreWebView2BrowserVersionString_t>(
"GetAvailableCoreWebView2BrowserVersionString");
} // namespace webview2_symbols
#endif /* WEBVIEW_MSWEBVIEW2_EXPLICIT_LINK */
class loader {
public:
HRESULT create_environment_with_options(
PCWSTR browser_dir, PCWSTR user_data_dir,
ICoreWebView2EnvironmentOptions *env_options,
ICoreWebView2CreateCoreWebView2EnvironmentCompletedHandler
*created_handler) const {
#if WEBVIEW_MSWEBVIEW2_EXPLICIT_LINK == 1
if (m_lib.is_loaded()) {
if (auto fn = m_lib.get(
webview2_symbols::CreateCoreWebView2EnvironmentWithOptions)) {
return fn(browser_dir, user_data_dir, env_options, created_handler);
}
}
#if WEBVIEW_MSWEBVIEW2_BUILTIN_IMPL == 1
return create_environment_with_options_impl(browser_dir, user_data_dir,
env_options, created_handler);
#else
return S_FALSE;
#endif
#else
return ::CreateCoreWebView2EnvironmentWithOptions(
browser_dir, user_data_dir, env_options, created_handler);
#endif /* WEBVIEW_MSWEBVIEW2_EXPLICIT_LINK */
}
HRESULT
get_available_browser_version_string(PCWSTR browser_dir,
LPWSTR *version) const {
#if WEBVIEW_MSWEBVIEW2_EXPLICIT_LINK == 1
if (m_lib.is_loaded()) {
if (auto fn = m_lib.get(
webview2_symbols::GetAvailableCoreWebView2BrowserVersionString)) {
return fn(browser_dir, version);
}
}
#if WEBVIEW_MSWEBVIEW2_BUILTIN_IMPL == 1
return get_available_browser_version_string_impl(browser_dir, version);
#else
return S_FALSE;
#endif
#else
return ::GetAvailableCoreWebView2BrowserVersionString(browser_dir, version);
#endif /* WEBVIEW_MSWEBVIEW2_EXPLICIT_LINK */
}
private:
#if WEBVIEW_MSWEBVIEW2_BUILTIN_IMPL == 1
struct client_info_t {
bool found = false;
std::wstring dll_path;
std::wstring version;
webview2_runtime_type runtime_type;
};
HRESULT create_environment_with_options_impl(
PCWSTR browser_dir, PCWSTR user_data_dir,
ICoreWebView2EnvironmentOptions *env_options,
ICoreWebView2CreateCoreWebView2EnvironmentCompletedHandler
*created_handler) const {
auto found_client = find_available_client(browser_dir);
if (!found_client.found) {
return -1;
}
auto client_dll = native_library(found_client.dll_path.c_str());
if (auto fn = client_dll.get(
webview2_symbols::CreateWebViewEnvironmentWithOptionsInternal)) {
return fn(true, found_client.runtime_type, user_data_dir, env_options,
created_handler);
}
if (auto fn = client_dll.get(webview2_symbols::DllCanUnloadNow)) {
if (!fn()) {
client_dll.detach();
}
}
return ERROR_SUCCESS;
}
HRESULT
get_available_browser_version_string_impl(PCWSTR browser_dir,
LPWSTR *version) const {
if (!version) {
return -1;
}
auto found_client = find_available_client(browser_dir);
if (!found_client.found) {
return -1;
}
auto info_length_bytes =
found_client.version.size() * sizeof(found_client.version[0]);
auto info = static_cast<LPWSTR>(CoTaskMemAlloc(info_length_bytes));
if (!info) {
return -1;
}
CopyMemory(info, found_client.version.c_str(), info_length_bytes);
*version = info;
return 0;
}
client_info_t find_available_client(PCWSTR browser_dir) const {
if (browser_dir) {
return find_embedded_client(api_version, browser_dir);
}
auto found_client =
find_installed_client(api_version, true, default_release_channel_guid);
if (!found_client.found) {
found_client = find_installed_client(api_version, false,
default_release_channel_guid);
}
return found_client;
}
std::wstring make_client_dll_path(const std::wstring &dir) const {
auto dll_path = dir;
if (!dll_path.empty()) {
auto last_char = dir[dir.size() - 1];
if (last_char != L'\\' && last_char != L'/') {
dll_path += L'\\';
}
}
dll_path += L"EBWebView\\";
#if defined(_M_X64) || defined(__x86_64__)
dll_path += L"x64";
#elif defined(_M_IX86) || defined(__i386__)
dll_path += L"x86";
#elif defined(_M_ARM64) || defined(__aarch64__)
dll_path += L"arm64";
#else
#error WebView2 integration for this platform is not yet supported.
#endif
dll_path += L"\\EmbeddedBrowserWebView.dll";
return dll_path;
}
client_info_t
find_installed_client(unsigned int min_api_version, bool system,
const std::wstring &release_channel) const {
std::wstring sub_key = client_state_reg_sub_key;
sub_key += release_channel;
auto root_key = system ? HKEY_LOCAL_MACHINE : HKEY_CURRENT_USER;
reg_key key(root_key, sub_key, 0, KEY_READ | KEY_WOW64_32KEY);
if (!key.is_open()) {
return {};
}
auto ebwebview_value = key.query_string(L"EBWebView");
auto client_version_string =
get_last_native_path_component(ebwebview_value);
auto client_version = parse_version(client_version_string);
if (client_version[2] < min_api_version) {
// Our API version is greater than the runtime API version.
return {};
}
auto client_dll_path = make_client_dll_path(ebwebview_value);
return {true, client_dll_path, client_version_string,
webview2_runtime_type::installed};
}
client_info_t find_embedded_client(unsigned int min_api_version,
const std::wstring &dir) const {
auto client_dll_path = make_client_dll_path(dir);
auto client_version_string = get_file_version_string(client_dll_path);
auto client_version = parse_version(client_version_string);
if (client_version[2] < min_api_version) {
// Our API version is greater than the runtime API version.
return {};
}
return {true, client_dll_path, client_version_string,
webview2_runtime_type::embedded};
}
// The minimum WebView2 API version we need regardless of the SDK release
// actually used. The number comes from the SDK release version,
// e.g. 1.0.1150.38. To be safe the SDK should have a number that is greater
// than or equal to this number. The Edge browser webview client must
// have a number greater than or equal to this number.
static constexpr unsigned int api_version = 1150;
static constexpr auto client_state_reg_sub_key =
L"SOFTWARE\\Microsoft\\EdgeUpdate\\ClientState\\";
// GUID for the stable release channel.
static constexpr auto stable_release_guid =
L"{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}";
static constexpr auto default_release_channel_guid = stable_release_guid;
#endif /* WEBVIEW_MSWEBVIEW2_BUILTIN_IMPL */
#if WEBVIEW_MSWEBVIEW2_EXPLICIT_LINK == 1
native_library m_lib{L"WebView2Loader.dll"};
#endif
};
namespace cast_info {
static constexpr auto controller_completed =
cast_info_t<ICoreWebView2CreateCoreWebView2ControllerCompletedHandler>{
IID_ICoreWebView2CreateCoreWebView2ControllerCompletedHandler};
static constexpr auto environment_completed =
cast_info_t<ICoreWebView2CreateCoreWebView2EnvironmentCompletedHandler>{
IID_ICoreWebView2CreateCoreWebView2EnvironmentCompletedHandler};
static constexpr auto message_received =
cast_info_t<ICoreWebView2WebMessageReceivedEventHandler>{
IID_ICoreWebView2WebMessageReceivedEventHandler};
static constexpr auto permission_requested =
cast_info_t<ICoreWebView2PermissionRequestedEventHandler>{
IID_ICoreWebView2PermissionRequestedEventHandler};
static constexpr auto new_window_requested =
cast_info_t<ICoreWebView2NewWindowRequestedEventHandler>{
IID_ICoreWebView2NewWindowRequestedEventHandler};
static constexpr auto navigation_starting =
cast_info_t<ICoreWebView2NavigationStartingEventHandler>{
IID_ICoreWebView2NavigationStartingEventHandler};
static constexpr auto script_dialog_opening =
cast_info_t<ICoreWebView2ScriptDialogOpeningEventHandler>{
IID_ICoreWebView2ScriptDialogOpeningEventHandler};
static constexpr auto context_menu_requested =
cast_info_t<ICoreWebView2ContextMenuRequestedEventHandler>{
IID_ICoreWebView2ContextMenuRequestedEventHandler};
} // namespace cast_info
} // namespace mswebview2
class webview2_com_handler
: public ICoreWebView2CreateCoreWebView2EnvironmentCompletedHandler,
public ICoreWebView2CreateCoreWebView2ControllerCompletedHandler,
public ICoreWebView2WebMessageReceivedEventHandler,
public ICoreWebView2PermissionRequestedEventHandler,
public ICoreWebView2NewWindowRequestedEventHandler,
public ICoreWebView2NavigationStartingEventHandler,
public ICoreWebView2ScriptDialogOpeningEventHandler,
public ICoreWebView2ContextMenuRequestedEventHandler,
public ICoreWebView2CustomItemSelectedEventHandler {
using webview2_com_handler_cb_t =
std::function<void(ICoreWebView2Controller *, ICoreWebView2_11 *webview)>;
public:
webview2_com_handler(HWND hwnd, msg_cb_t msgCb, webview2_com_handler_cb_t cb)
: m_window(hwnd), m_msgCb(msgCb), m_cb(cb) {}
virtual ~webview2_com_handler() = default;
webview2_com_handler(const webview2_com_handler &other) = delete;
webview2_com_handler &operator=(const webview2_com_handler &other) = delete;
webview2_com_handler(webview2_com_handler &&other) = delete;
webview2_com_handler &operator=(webview2_com_handler &&other) = delete;
ULONG STDMETHODCALLTYPE AddRef() { return ++m_ref_count; }
ULONG STDMETHODCALLTYPE Release() {
if (m_ref_count > 1) {
return --m_ref_count;
}
delete this;
return 0;
}
HRESULT STDMETHODCALLTYPE Invoke(
ICoreWebView2 *sender,
ICoreWebView2NewWindowRequestedEventArgs *args) {
LPWSTR uri;
args->get_Uri(&uri);
if (uri) {
std::wstring wuri(uri);
std::string uri_str = narrow_string(wuri);
std::string uri_lower = uri_str;
std::transform(uri_lower.begin(), uri_lower.end(), uri_lower.begin(), ::tolower);
bool is_localhost =
uri_lower.find("localhost") != std::string::npos ||
uri_lower.find("127.0.0.1") != std::string::npos ||
uri_lower.find("::1") != std::string::npos;
args->put_Handled(TRUE);
if (is_localhost) {
std::string path = "/";
size_t protocol_end = uri_str.find("://");
if (protocol_end != std::string::npos) {
size_t path_start = uri_str.find("/", protocol_end + 3);
if (path_start != std::string::npos) {
path = uri_str.substr(path_start);
}
}
std::string js = "history.pushState({}, '', '" + path + "'); window.dispatchEvent(new PopStateEvent('popstate'));";
std::wstring wjs = widen_string(js);
sender->ExecuteScript(wjs.c_str(), nullptr);
} else {
ShellExecuteW(nullptr, L"open", uri, nullptr, nullptr, SW_SHOWNORMAL);
}
CoTaskMemFree(uri);
}
return S_OK;
}
HRESULT STDMETHODCALLTYPE Invoke(
ICoreWebView2 *sender,
ICoreWebView2NavigationStartingEventArgs *args) {
LPWSTR uri;
args->get_Uri(&uri);
if (uri) {
std::wstring wuri(uri);
std::string uri_str = narrow_string(wuri);
std::string uri_lower = uri_str;
std::transform(uri_lower.begin(), uri_lower.end(), uri_lower.begin(), ::tolower);
bool is_external =
(uri_lower.find("http://") == 0 || uri_lower.find("https://") == 0) &&
uri_lower.find("localhost") == std::string::npos &&
uri_lower.find("127.0.0.1") == std::string::npos &&
uri_lower.find("::1") == std::string::npos;
if (is_external) {
args->put_Cancel(TRUE);
ShellExecuteW(nullptr, L"open", uri, nullptr, nullptr, SW_SHOWNORMAL);
}
CoTaskMemFree(uri);
}
return S_OK;
}
HRESULT STDMETHODCALLTYPE Invoke(
ICoreWebView2 *sender,
ICoreWebView2ScriptDialogOpeningEventArgs *args) {
LPWSTR message;
COREWEBVIEW2_SCRIPT_DIALOG_KIND kind;
args->get_Message(&message);
args->get_Kind(&kind);
if (kind == COREWEBVIEW2_SCRIPT_DIALOG_KIND_CONFIRM) {
int result = MessageBoxW(m_window, message, L"Confirm", MB_OKCANCEL | MB_ICONQUESTION);
if (result == IDOK) {
args->Accept();
}
CoTaskMemFree(message);
SetFocus(m_window);
return S_OK;
}
CoTaskMemFree(message);
return S_OK;
}
HRESULT STDMETHODCALLTYPE Invoke(
ICoreWebView2 *sender,
ICoreWebView2ContextMenuRequestedEventArgs *args) {
ICoreWebView2ContextMenuItemCollection* items = nullptr;
if (FAILED(args->get_MenuItems(&items)) || !items) {
return S_OK;
}
UINT count = 0;
items->get_Count(&count);
// remove unwanted default items
for (int idx = static_cast<int>(count) - 1; idx >= 0; --idx) {
ICoreWebView2ContextMenuItem* item = nullptr;
items->GetValueAtIndex(idx, &item);
if (!item) continue;
LPWSTR name = nullptr;
item->get_Name(&name);
if (name != nullptr) {
std::wstring itemName(name);
std::transform(itemName.begin(), itemName.end(), itemName.begin(), ::towlower);
if (itemName.find(L"saveas") != std::wstring::npos ||
itemName.find(L"print") != std::wstring::npos ||
itemName.find(L"copylinktohighlight") != std::wstring::npos ||
itemName.find(L"back") != std::wstring::npos ||
itemName.find(L"forward") != std::wstring::npos ||
itemName.find(L"reload") != std::wstring::npos ||
itemName.find(L"share") != std::wstring::npos ||
itemName.find(L"screenshot") != std::wstring::npos ||
itemName.find(L"webcapture") != std::wstring::npos ||
itemName.find(L"openlinkinnewwindow") != std::wstring::npos ||
itemName.find(L"savelinkas") != std::wstring::npos ||
itemName.find(L"copylinklocation") != std::wstring::npos) {
items->RemoveValueAtIndex(idx);
}
}
if (name) {
CoTaskMemFree(name);
}
if (item) {
item->Release();
}
}
items->get_Count(&count);
// Add custom menu items at the top
int item_count = menu_get_item_count();
if (item_count > 0) {
// std::wstring message = L"Found " + std::to_wstring(customItemCount) + L" custom menu items";
// MessageBoxW(m_window, message.c_str(), L"Custom Menu Items", MB_OK | MB_ICONINFORMATION);
menuItem* customItems = static_cast<menuItem*>(menu_get_items());
if (customItems) {
ICoreWebView2_11 *webview = nullptr;
if (sender) {
sender->QueryInterface(IID_ICoreWebView2_11,
reinterpret_cast<void **>(&webview));
}
ICoreWebView2Environment* env_base = nullptr;
ICoreWebView2Environment9* env = nullptr;
if (SUCCEEDED(webview->get_Environment(&env_base)) && env_base) {
env_base->QueryInterface(IID_ICoreWebView2Environment9, (void**)&env);
env_base->Release();
int insertIndex = 0;
for (int i = 0; i < item_count; i++) {
if (customItems[i].separator) {
ICoreWebView2ContextMenuItem* separator = nullptr;
env->CreateContextMenuItem(L"", nullptr, COREWEBVIEW2_CONTEXT_MENU_ITEM_KIND_SEPARATOR, &separator);
if (separator) {
items->InsertValueAtIndex(insertIndex++, separator);
separator->Release();
}
} else if (customItems[i].label) {
std::string label(customItems[i].label);
std::wstring wlabel = widen_string(label);
ICoreWebView2ContextMenuItem* menuItem = nullptr;
env->CreateContextMenuItem(wlabel.c_str(), nullptr, COREWEBVIEW2_CONTEXT_MENU_ITEM_KIND_COMMAND, &menuItem);
if (menuItem) {
menuItem->put_IsEnabled(customItems[i].enabled);
// Add click handler
EventRegistrationToken token;
menuItem->add_CustomItemSelected(this, &token);
// Insert at the top (current insertIndex)
items->InsertValueAtIndex(insertIndex++, menuItem);
menuItem->Release();
}
}
}
items->get_Count(&count);
if (count > 0) {
ICoreWebView2ContextMenuItem* separator = nullptr;
env->CreateContextMenuItem(L"", nullptr, COREWEBVIEW2_CONTEXT_MENU_ITEM_KIND_SEPARATOR, &separator);
if (separator) {
items->InsertValueAtIndex(insertIndex, separator);
separator->Release();
}
}
env->Release();
}
}
}
return S_OK;
}
HRESULT STDMETHODCALLTYPE Invoke(
ICoreWebView2ContextMenuItem* sender,
IUnknown* /*args*/) {
LPWSTR label = nullptr;
if (sender && SUCCEEDED(sender->get_Label(&label)) && label) {
std::wstring wlabel(label);
std::string itemLabel = narrow_string(wlabel);
menu_handle_selection((char*)itemLabel.c_str());
CoTaskMemFree(label);
}
return S_OK;
}
HRESULT STDMETHODCALLTYPE QueryInterface(REFIID riid, LPVOID *ppv) {
using namespace mswebview2::cast_info;
if (!ppv) {
return E_POINTER;
}
// All of the COM interfaces we implement should be added here regardless
// of whether they are required.
// This is just to be on the safe side in case the WebView2 Runtime ever
// requests a pointer to an interface we implement.
// The WebView2 Runtime must at the very least be able to get a pointer to
// ICoreWebView2CreateCoreWebView2EnvironmentCompletedHandler when we use
// our custom WebView2 loader implementation, and observations have shown
// that it is the only interface requested in this case. None have been
// observed to be requested when using the official WebView2 loader.
if (cast_if_equal_iid(riid, controller_completed, ppv) ||
cast_if_equal_iid(riid, environment_completed, ppv) ||
cast_if_equal_iid(riid, message_received, ppv) ||
cast_if_equal_iid(riid, permission_requested, ppv) ||
cast_if_equal_iid(riid, new_window_requested, ppv) ||
cast_if_equal_iid(riid, navigation_starting, ppv) ||
cast_if_equal_iid(riid, script_dialog_opening, ppv) ||
cast_if_equal_iid(riid, context_menu_requested, ppv)) {
return S_OK;
}
return E_NOINTERFACE;
}
HRESULT STDMETHODCALLTYPE Invoke(HRESULT res, ICoreWebView2Environment *env) {
if (SUCCEEDED(res)) {
res = env->CreateCoreWebView2Controller(m_window, this);
if (SUCCEEDED(res)) {
return S_OK;
}
}
try_create_environment();
return S_OK;
}
HRESULT STDMETHODCALLTYPE Invoke(HRESULT res,
ICoreWebView2Controller *controller) {
if (FAILED(res)) {
// See try_create_environment() regarding
// HRESULT_FROM_WIN32(ERROR_INVALID_STATE).
// The result is E_ABORT if the parent window has been destroyed already.
switch (res) {
case HRESULT_FROM_WIN32(ERROR_INVALID_STATE):
case E_ABORT:
return S_OK;
}
try_create_environment();
return S_OK;
}
ICoreWebView2 *webview_base = nullptr;
controller->get_CoreWebView2(&webview_base);
ICoreWebView2_11 *webview = nullptr;
if (webview_base) {
webview_base->QueryInterface(IID_ICoreWebView2_11,
reinterpret_cast<void **>(&webview));
webview_base->Release();
}
if (!webview) {
return E_FAIL;
}
::EventRegistrationToken token;
webview->add_WebMessageReceived(this, &token);
webview->add_PermissionRequested(this, &token);
webview->add_NewWindowRequested(this, &token);
webview->add_NavigationStarting(this, &token);
webview->add_ScriptDialogOpening(this, &token);
webview->add_ContextMenuRequested(this, &token);
m_cb(controller, webview);
return S_OK;
}
HRESULT STDMETHODCALLTYPE Invoke(
ICoreWebView2 *sender, ICoreWebView2WebMessageReceivedEventArgs *args) {
LPWSTR message;
args->TryGetWebMessageAsString(&message);
m_msgCb(narrow_string(message));
sender->PostWebMessageAsString(message);
CoTaskMemFree(message);
return S_OK;
}
HRESULT STDMETHODCALLTYPE
Invoke(ICoreWebView2 * /*sender*/,
ICoreWebView2PermissionRequestedEventArgs *args) {
COREWEBVIEW2_PERMISSION_KIND kind;
args->get_PermissionKind(&kind);
if (kind == COREWEBVIEW2_PERMISSION_KIND_CLIPBOARD_READ) {
args->put_State(COREWEBVIEW2_PERMISSION_STATE_ALLOW);
}
return S_OK;
}
// Checks whether the specified IID equals the IID of the specified type and
// if so casts the "this" pointer to T and returns it. Returns nullptr on
// mismatching IIDs.
// If ppv is specified then the pointer will also be assigned to *ppv.
template <typename T>
T *cast_if_equal_iid(REFIID riid, const cast_info_t<T> &info,
LPVOID *ppv = nullptr) noexcept {
T *ptr = nullptr;
if (IsEqualIID(riid, info.iid)) {
ptr = static_cast<T *>(this);
ptr->AddRef();
}
if (ppv) {
*ppv = ptr;
}
return ptr;
}
// Set the function that will perform the initiating logic for creating
// the WebView2 environment.
void set_attempt_handler(std::function<HRESULT()> attempt_handler) noexcept {
m_attempt_handler = attempt_handler;
}
// Retry creating a WebView2 environment.
// The initiating logic for creating the environment is defined by the
// caller of set_attempt_handler().
void try_create_environment() noexcept {
// WebView creation fails with HRESULT_FROM_WIN32(ERROR_INVALID_STATE) if
// a running instance using the same user data folder exists, and the
// Environment objects have different EnvironmentOptions.
// Source: https://docs.microsoft.com/en-us/microsoft-edge/webview2/reference/win32/icorewebview2environment?view=webview2-1.0.1150.38
if (m_attempts < m_max_attempts) {
++m_attempts;
auto res = m_attempt_handler();
if (SUCCEEDED(res)) {
return;
}
// Not entirely sure if this error code only applies to
// CreateCoreWebView2Controller so we check here as well.
if (res == HRESULT_FROM_WIN32(ERROR_INVALID_STATE)) {
return;
}
try_create_environment();
return;
}
// Give up.
m_cb(nullptr, nullptr);
}
private:
HWND m_window;
msg_cb_t m_msgCb;
webview2_com_handler_cb_t m_cb;
std::atomic<ULONG> m_ref_count{1};
std::function<HRESULT()> m_attempt_handler;
unsigned int m_max_attempts = 5;
unsigned int m_attempts = 0;
};
class win32_edge_engine : public engine_base {
public:
win32_edge_engine(bool debug, void *window) : m_owns_window{!window} {
if (!is_webview2_available()) {
return;
}
HINSTANCE hInstance = GetModuleHandle(nullptr);
if (m_owns_window) {
m_com_init = {COINIT_APARTMENTTHREADED};
if (!m_com_init.is_initialized()) {
return;
}
enable_dpi_awareness();
HICON icon = (HICON)LoadImage(
hInstance, IDI_APPLICATION, IMAGE_ICON, GetSystemMetrics(SM_CXICON),
GetSystemMetrics(SM_CYICON), LR_DEFAULTCOLOR);
// Create a top-level window.
WNDCLASSEXW wc;
ZeroMemory(&wc, sizeof(WNDCLASSEX));
wc.cbSize = sizeof(WNDCLASSEX);
wc.hInstance = hInstance;
wc.lpszClassName = L"webview";
wc.hIcon = icon;
wc.lpfnWndProc = (WNDPROC)(+[](HWND hwnd, UINT msg, WPARAM wp,
LPARAM lp) -> LRESULT {
win32_edge_engine *w{};
if (msg == WM_NCCREATE) {
auto *lpcs{reinterpret_cast<LPCREATESTRUCT>(lp)};
w = static_cast<win32_edge_engine *>(lpcs->lpCreateParams);
w->m_window = hwnd;
SetWindowLongPtrW(hwnd, GWLP_USERDATA, reinterpret_cast<LONG_PTR>(w));
enable_non_client_dpi_scaling_if_needed(hwnd);
apply_window_theme(hwnd);
} else {
w = reinterpret_cast<win32_edge_engine *>(
GetWindowLongPtrW(hwnd, GWLP_USERDATA));
}
if (!w) {
return DefWindowProcW(hwnd, msg, wp, lp);
}
switch (msg) {
case WM_SIZE:
w->resize_widget();
break;
case WM_CLOSE:
DestroyWindow(hwnd);
break;
case WM_DESTROY:
w->m_window = nullptr;
SetWindowLongPtrW(hwnd, GWLP_USERDATA, 0);
w->on_window_destroyed();
break;
case WM_GETMINMAXINFO: {
auto lpmmi = (LPMINMAXINFO)lp;
if (w->m_maxsz.x > 0 && w->m_maxsz.y > 0) {
lpmmi->ptMaxSize = w->m_maxsz;
lpmmi->ptMaxTrackSize = w->m_maxsz;
}
if (w->m_minsz.x > 0 && w->m_minsz.y > 0) {
lpmmi->ptMinTrackSize = w->m_minsz;
}
} break;
case 0x02E4 /*WM_GETDPISCALEDSIZE*/: {
auto dpi = static_cast<int>(wp);
auto *size{reinterpret_cast<SIZE *>(lp)};
*size = w->get_scaled_size(w->m_dpi, dpi);
return TRUE;
}
case 0x02E0 /*WM_DPICHANGED*/: {
// Windows 10: The size we get here is exactly what we supplied to WM_GETDPISCALEDSIZE.
// Windows 11: The size we get here is NOT what we supplied to WM_GETDPISCALEDSIZE.
// Due to this difference, don't use the suggested bounds.
auto dpi = static_cast<int>(HIWORD(wp));
w->on_dpi_changed(dpi);
break;
}
case WM_SETTINGCHANGE: {
auto *area = reinterpret_cast<const wchar_t *>(lp);
if (area) {
w->on_system_setting_change(area);
}
break;
}
case WM_ACTIVATE:
if (LOWORD(wp) != WA_INACTIVE) {
w->focus_webview();
}
break;
default:
return DefWindowProcW(hwnd, msg, wp, lp);
}
return 0;
});
RegisterClassExW(&wc);
CreateWindowW(L"webview", L"", WS_OVERLAPPEDWINDOW, CW_USEDEFAULT,
CW_USEDEFAULT, 0, 0, nullptr, nullptr, hInstance, this);
if (m_window == nullptr) {
return;
}
on_window_created();
m_dpi = get_window_dpi(m_window);
constexpr const int initial_width = 640;
constexpr const int initial_height = 480;
set_size(initial_width, initial_height, WEBVIEW_HINT_NONE);
} else {
m_window = IsWindow(static_cast<HWND>(window))
? static_cast<HWND>(window)
: *(static_cast<HWND *>(window));
m_dpi = get_window_dpi(m_window);
}
// Create a window that WebView2 will be embedded into.
WNDCLASSEXW widget_wc{};
widget_wc.cbSize = sizeof(WNDCLASSEX);
widget_wc.hInstance = hInstance;
widget_wc.lpszClassName = L"webview_widget";
widget_wc.lpfnWndProc = (WNDPROC)(+[](HWND hwnd, UINT msg, WPARAM wp,
LPARAM lp) -> LRESULT {
win32_edge_engine *w{};
if (msg == WM_NCCREATE) {
auto *lpcs{reinterpret_cast<LPCREATESTRUCT>(lp)};
w = static_cast<win32_edge_engine *>(lpcs->lpCreateParams);
w->m_widget = hwnd;
SetWindowLongPtrW(hwnd, GWLP_USERDATA, reinterpret_cast<LONG_PTR>(w));
} else {
w = reinterpret_cast<win32_edge_engine *>(
GetWindowLongPtrW(hwnd, GWLP_USERDATA));
}
if (!w) {
return DefWindowProcW(hwnd, msg, wp, lp);
}
switch (msg) {
case WM_SIZE:
w->resize_webview();
break;
case WM_DESTROY:
w->m_widget = nullptr;
SetWindowLongPtrW(hwnd, GWLP_USERDATA, 0);
break;
default:
return DefWindowProcW(hwnd, msg, wp, lp);
}
return 0;
});
RegisterClassExW(&widget_wc);
CreateWindowExW(WS_EX_CONTROLPARENT, L"webview_widget", nullptr, WS_CHILD,
0, 0, 0, 0, m_window, nullptr, hInstance, this);
// Create a message-only window for internal messaging.
WNDCLASSEXW message_wc{};
message_wc.cbSize = sizeof(WNDCLASSEX);
message_wc.hInstance = hInstance;
message_wc.lpszClassName = L"webview_message";
message_wc.lpfnWndProc = (WNDPROC)(+[](HWND hwnd, UINT msg, WPARAM wp,
LPARAM lp) -> LRESULT {
win32_edge_engine *w{};
if (msg == WM_NCCREATE) {
auto *lpcs{reinterpret_cast<LPCREATESTRUCT>(lp)};
w = static_cast<win32_edge_engine *>(lpcs->lpCreateParams);
w->m_message_window = hwnd;
SetWindowLongPtrW(hwnd, GWLP_USERDATA, reinterpret_cast<LONG_PTR>(w));
} else {
w = reinterpret_cast<win32_edge_engine *>(
GetWindowLongPtrW(hwnd, GWLP_USERDATA));
}
if (!w) {
return DefWindowProcW(hwnd, msg, wp, lp);
}
switch (msg) {
case WM_APP:
if (auto f = (dispatch_fn_t *)(lp)) {
(*f)();
delete f;
}
break;
case WM_DESTROY:
w->m_message_window = nullptr;
SetWindowLongPtrW(hwnd, GWLP_USERDATA, 0);
break;
default:
return DefWindowProcW(hwnd, msg, wp, lp);
}
return 0;
});
RegisterClassExW(&message_wc);
CreateWindowExW(0, L"webview_message", nullptr, 0, 0, 0, 0, 0, HWND_MESSAGE,
nullptr, hInstance, this);
if (m_owns_window) {
// ShowWindow(m_window, SW_SHOW);
UpdateWindow(m_window);
SetFocus(m_window);
}
auto cb =
std::bind(&win32_edge_engine::on_message, this, std::placeholders::_1);
embed(m_widget, debug, cb);
}
virtual ~win32_edge_engine() {
if (m_com_handler) {
m_com_handler->Release();
m_com_handler = nullptr;
}
if (m_webview) {
m_webview->Release();
m_webview = nullptr;
}
if (m_controller) {
m_controller->Release();
m_controller = nullptr;
}
// Replace wndproc to avoid callbacks and other bad things during
// destruction.
auto wndproc = reinterpret_cast<LONG_PTR>(
+[](HWND hwnd, UINT msg, WPARAM wp, LPARAM lp) -> LRESULT {
return DefWindowProcW(hwnd, msg, wp, lp);
});
if (m_widget) {
SetWindowLongPtrW(m_widget, GWLP_WNDPROC, wndproc);
}
if (m_window && m_owns_window) {
SetWindowLongPtrW(m_window, GWLP_WNDPROC, wndproc);
}
if (m_widget) {
DestroyWindow(m_widget);
m_widget = nullptr;
}
if (m_window) {
if (m_owns_window) {
DestroyWindow(m_window);
on_window_destroyed(true);
}
m_window = nullptr;
}
if (m_owns_window) {
// Not strictly needed for windows to close immediately but aligns
// behavior across backends.
deplete_run_loop_event_queue();
}
// We need the message window in order to deplete the event queue.
if (m_message_window) {
SetWindowLongPtrW(m_message_window, GWLP_WNDPROC, wndproc);
DestroyWindow(m_message_window);
m_message_window = nullptr;
}
}
win32_edge_engine(const win32_edge_engine &other) = delete;
win32_edge_engine &operator=(const win32_edge_engine &other) = delete;
win32_edge_engine(win32_edge_engine &&other) = delete;
win32_edge_engine &operator=(win32_edge_engine &&other) = delete;
void run_impl() override {
MSG msg;
while (GetMessageW(&msg, nullptr, 0, 0) > 0) {
TranslateMessage(&msg);
DispatchMessageW(&msg);
}
}
void *window_impl() override { return (void *)m_window; }
void *widget_impl() override { return (void *)m_widget; }
void *browser_controller_impl() override { return (void *)m_controller; }
void terminate_impl() override { PostQuitMessage(0); }
void dispatch_impl(dispatch_fn_t f) override {
PostMessageW(m_message_window, WM_APP, 0, (LPARAM) new dispatch_fn_t(f));
}
void set_title_impl(const std::string &title) override {
SetWindowTextW(m_window, widen_string(title).c_str());
}
void set_size_impl(int width, int height, webview_hint_t hints) override {
auto style = GetWindowLong(m_window, GWL_STYLE);
if (hints == WEBVIEW_HINT_FIXED) {
style &= ~(WS_THICKFRAME | WS_MAXIMIZEBOX);
} else {
style |= (WS_THICKFRAME | WS_MAXIMIZEBOX);
}
SetWindowLong(m_window, GWL_STYLE, style);
if (hints == WEBVIEW_HINT_MAX) {
m_maxsz.x = width;
m_maxsz.y = height;
} else if (hints == WEBVIEW_HINT_MIN) {
m_minsz.x = width;
m_minsz.y = height;
} else {
auto dpi = get_window_dpi(m_window);
m_dpi = dpi;
auto scaled_size =
scale_size(width, height, get_default_window_dpi(), dpi);
auto frame_size =
make_window_frame_size(m_window, scaled_size.cx, scaled_size.cy, dpi);
SetWindowPos(m_window, nullptr, 0, 0, frame_size.cx, frame_size.cy,
SWP_NOZORDER | SWP_NOACTIVATE | SWP_NOMOVE |
SWP_FRAMECHANGED);
}
}
void navigate_impl(const std::string &url) override {
auto wurl = widen_string(url);
m_webview->Navigate(wurl.c_str());
}
void init_impl(const std::string &js) override {
auto wjs = widen_string(js);
m_webview->AddScriptToExecuteOnDocumentCreated(wjs.c_str(), nullptr);
}
void eval_impl(const std::string &js) override {
auto wjs = widen_string(js);
m_webview->ExecuteScript(wjs.c_str(), nullptr);
}
void set_html_impl(const std::string &html) override {
m_webview->NavigateToString(widen_string(html).c_str());
}
void set_zoom_impl(double level) override {
if (m_controller) {
m_controller->put_ZoomFactor(level);
}
}
double get_zoom_impl() override {
double zoom = 1.0;
if (m_controller) {
m_controller->get_ZoomFactor(&zoom);
}
return zoom;
}
private:
bool embed(HWND wnd, bool debug, msg_cb_t cb) {
std::atomic_flag flag = ATOMIC_FLAG_INIT;
flag.test_and_set();
wchar_t currentExePath[MAX_PATH];
GetModuleFileNameW(nullptr, currentExePath, MAX_PATH);
wchar_t *currentExeName = PathFindFileNameW(currentExePath);
wchar_t dataPath[MAX_PATH];
if (!SUCCEEDED(
SHGetFolderPathW(nullptr, CSIDL_APPDATA, nullptr, 0, dataPath))) {
return false;
}
wchar_t userDataFolder[MAX_PATH];
PathCombineW(userDataFolder, dataPath, currentExeName);
m_com_handler = new webview2_com_handler(
wnd, cb,
[&](ICoreWebView2Controller *controller, ICoreWebView2_11 *webview) {
if (!controller || !webview) {
flag.clear();
return;
}
controller->AddRef();
webview->AddRef();
m_controller = controller;
m_webview = webview;
flag.clear();
});
m_com_handler->set_attempt_handler([&] {
return m_webview2_loader.create_environment_with_options(
nullptr, userDataFolder, nullptr, m_com_handler);
});
m_com_handler->try_create_environment();
// Pump the message loop until WebView2 has finished initialization.
bool got_quit_msg = false;
MSG msg;
while (flag.test_and_set() && GetMessageW(&msg, nullptr, 0, 0) >= 0) {
if (msg.message == WM_QUIT) {
got_quit_msg = true;
break;
}
TranslateMessage(&msg);
DispatchMessageW(&msg);
}
if (got_quit_msg) {
return false;
}
if (!m_controller || !m_webview) {
return false;
}
ICoreWebView2Settings *settings = nullptr;
auto res = m_webview->get_Settings(&settings);
if (res != S_OK) {
return false;
}
res = settings->put_AreDevToolsEnabled(debug ? TRUE : FALSE);
if (res != S_OK) {
return false;
}
res = settings->put_IsStatusBarEnabled(FALSE);
if (res != S_OK) {
return false;
}
res = settings->put_AreDefaultScriptDialogsEnabled(FALSE);
if (res != S_OK) {
return false;
}
res = settings->put_IsZoomControlEnabled(FALSE);
if (res != S_OK) {
return false;
}
init("window.external={invoke:s=>window.chrome.webview.postMessage(s)}");
resize_webview();
m_controller->put_IsVisible(TRUE);
ShowWindow(m_widget, SW_SHOW);
UpdateWindow(m_widget);
if (m_owns_window) {
focus_webview();
}
return true;
}
void resize_widget() {
if (m_widget) {
RECT r{};
if (GetClientRect(GetParent(m_widget), &r)) {
MoveWindow(m_widget, r.left, r.top, r.right - r.left, r.bottom - r.top,
TRUE);
}
}
}
void resize_webview() {
if (m_widget && m_controller) {
RECT bounds{};
if (GetClientRect(m_widget, &bounds)) {
m_controller->put_Bounds(bounds);
}
}
}
void focus_webview() {
if (m_controller) {
m_controller->MoveFocus(COREWEBVIEW2_MOVE_FOCUS_REASON_PROGRAMMATIC);
}
}
bool is_webview2_available() const noexcept {
LPWSTR version_info = nullptr;
auto res = m_webview2_loader.get_available_browser_version_string(
nullptr, &version_info);
// The result will be equal to HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND)
// if the WebView2 runtime is not installed.
auto ok = SUCCEEDED(res) && version_info;
if (version_info) {
CoTaskMemFree(version_info);
}
return ok;
}
void on_dpi_changed(int dpi) {
auto scaled_size = get_scaled_size(m_dpi, dpi);
auto frame_size =
make_window_frame_size(m_window, scaled_size.cx, scaled_size.cy, dpi);
SetWindowPos(m_window, nullptr, 0, 0, frame_size.cx, frame_size.cy,
SWP_NOZORDER | SWP_NOACTIVATE | SWP_NOMOVE | SWP_FRAMECHANGED);
m_dpi = dpi;
}
SIZE get_size() const {
RECT bounds;
GetClientRect(m_window, &bounds);
auto width = bounds.right - bounds.left;
auto height = bounds.bottom - bounds.top;
return {width, height};
}
SIZE get_scaled_size(int from_dpi, int to_dpi) const {
auto size = get_size();
return scale_size(size.cx, size.cy, from_dpi, to_dpi);
}
void on_system_setting_change(const wchar_t *area) {
// Detect light/dark mode change in system.
if (lstrcmpW(area, L"ImmersiveColorSet") == 0) {
apply_window_theme(m_window);
}
}
// Blocks while depleting the run loop of events.
void deplete_run_loop_event_queue() {
bool done{};
dispatch([&] { done = true; });
while (!done) {
MSG msg;
if (GetMessageW(&msg, nullptr, 0, 0) > 0) {
TranslateMessage(&msg);
DispatchMessageW(&msg);
}
}
}
// The app is expected to call CoInitializeEx before
// CreateCoreWebView2EnvironmentWithOptions.
// Source: https://docs.microsoft.com/en-us/microsoft-edge/webview2/reference/win32/webview2-idl#createcorewebview2environmentwithoptions
com_init_wrapper m_com_init;
HWND m_window = nullptr;
HWND m_widget = nullptr;
HWND m_message_window = nullptr;
POINT m_minsz = POINT{0, 0};
POINT m_maxsz = POINT{0, 0};
DWORD m_main_thread = GetCurrentThreadId();
ICoreWebView2_11 *m_webview = nullptr;
ICoreWebView2Controller *m_controller = nullptr;
webview2_com_handler *m_com_handler = nullptr;
mswebview2::loader m_webview2_loader;
int m_dpi{};
bool m_owns_window{};
};
} // namespace detail
using browser_engine = detail::win32_edge_engine;
} // namespace webview
#endif /* WEBVIEW_GTK, WEBVIEW_COCOA, WEBVIEW_EDGE */
namespace webview {
using webview = browser_engine;
} // namespace webview
WEBVIEW_API webview_t webview_create(int debug, void *wnd) {
auto w = new webview::webview(debug, wnd);
if (!w->window()) {
delete w;
return nullptr;
}
return w;
}
WEBVIEW_API void webview_destroy(webview_t w) {
delete static_cast<webview::webview *>(w);
}
WEBVIEW_API void webview_run(webview_t w) {
static_cast<webview::webview *>(w)->run();
}
WEBVIEW_API void webview_terminate(webview_t w) {
static_cast<webview::webview *>(w)->terminate();
}
WEBVIEW_API void webview_dispatch(webview_t w, void (*fn)(webview_t, void *),
void *arg) {
static_cast<webview::webview *>(w)->dispatch([=]() { fn(w, arg); });
}
WEBVIEW_API void *webview_get_window(webview_t w) {
return static_cast<webview::webview *>(w)->window();
}
WEBVIEW_API void *webview_get_native_handle(webview_t w,
webview_native_handle_kind_t kind) {
auto *w_ = static_cast<webview::webview *>(w);
switch (kind) {
case WEBVIEW_NATIVE_HANDLE_KIND_UI_WINDOW:
return w_->window();
case WEBVIEW_NATIVE_HANDLE_KIND_UI_WIDGET:
return w_->widget();
case WEBVIEW_NATIVE_HANDLE_KIND_BROWSER_CONTROLLER:
return w_->browser_controller();
default:
return nullptr;
}
}
WEBVIEW_API void webview_set_title(webview_t w, const char *title) {
static_cast<webview::webview *>(w)->set_title(title);
}
WEBVIEW_API void webview_set_size(webview_t w, int width, int height,
webview_hint_t hints) {
static_cast<webview::webview *>(w)->set_size(width, height, hints);
}
WEBVIEW_API void webview_navigate(webview_t w, const char *url) {
static_cast<webview::webview *>(w)->navigate(url);
}
WEBVIEW_API void webview_set_html(webview_t w, const char *html) {
static_cast<webview::webview *>(w)->set_html(html);
}
WEBVIEW_API void webview_init(webview_t w, const char *js) {
static_cast<webview::webview *>(w)->init(js);
}
WEBVIEW_API void webview_eval(webview_t w, const char *js) {
static_cast<webview::webview *>(w)->eval(js);
}
WEBVIEW_API void webview_bind(webview_t w, const char *name,
void (*fn)(const char *seq, const char *req,
void *arg),
void *arg) {
static_cast<webview::webview *>(w)->bind(
name,
[=](const std::string &seq, const std::string &req, void *arg) {
fn(seq.c_str(), req.c_str(), arg);
},
arg);
}
WEBVIEW_API void webview_unbind(webview_t w, const char *name) {
static_cast<webview::webview *>(w)->unbind(name);
}
WEBVIEW_API void webview_return(webview_t w, const char *seq, int status,
const char *result) {
static_cast<webview::webview *>(w)->resolve(seq, status, result);
}
WEBVIEW_API const webview_version_info_t *webview_version(void) {
return &webview::detail::library_version_info;
}
WEBVIEW_API void webview_set_zoom(webview_t w, double level) {
static_cast<webview::webview *>(w)->set_zoom(level);
}
WEBVIEW_API double webview_get_zoom(webview_t w) {
return static_cast<webview::webview *>(w)->get_zoom();
}
#endif /* WEBVIEW_HEADER */
#endif /* __cplusplus */
#endif /* WEBVIEW_H */
......@@ -11,13 +11,13 @@ import (
"golang.org/x/sys/windows"
)
var quitOnce sync.Once
func (t *winTray) Run() {
nativeLoop()
}
var (
quitOnce sync.Once
UI_REQUEST_MSG_ID = WM_USER + 2
FOCUS_WINDOW_MSG_ID = WM_USER + 3
)
func nativeLoop() {
func (t *winTray) TrayRun() {
// Main message pump.
slog.Debug("starting event handling loop")
m := &struct {
......@@ -32,6 +32,21 @@ func nativeLoop() {
for {
ret, _, err := pGetMessage.Call(uintptr(unsafe.Pointer(m)), 0, 0, 0)
// Ignore WM_QUIT messages from the UI window, which shouldn't exit the main app
if m.Message == WM_QUIT && t.app.UIRunning() {
if t.app != nil {
slog.Debug("converting WM_QUIT to terminate call on webview")
t.app.UITerminate()
}
// Drain any other WM_QUIT messages
for {
ret, _, err = pGetMessage.Call(uintptr(unsafe.Pointer(m)), 0, 0, 0)
if m.Message != WM_QUIT {
break
}
}
}
// If the function retrieves a message other than WM_QUIT, the return value is nonzero.
// If the function retrieves the WM_QUIT message, the return value is zero.
// If there is an error, the return value is -1
......@@ -41,6 +56,7 @@ func nativeLoop() {
slog.Error(fmt.Sprintf("get message failure: %v", err))
return
case 0:
// slog.Debug("XXX tray run loop exiting from handling", "message", fmt.Sprintf("0x%x", m.Message), "wParam", fmt.Sprintf("0x%x", m.Wparam), "lParam", fmt.Sprintf("0x%x", m.Lparam))
return
default:
pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
......@@ -52,59 +68,52 @@ func nativeLoop() {
// WindowProc callback function that processes messages sent to a window.
// https://msdn.microsoft.com/en-us/library/windows/desktop/ms633573(v=vs.85).aspx
func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam uintptr) (lResult uintptr) {
const (
WM_RBUTTONUP = 0x0205
WM_LBUTTONUP = 0x0202
WM_COMMAND = 0x0111
WM_ENDSESSION = 0x0016
WM_CLOSE = 0x0010
WM_DESTROY = 0x0002
WM_MOUSEMOVE = 0x0200
WM_LBUTTONDOWN = 0x0201
)
// slog.Debug("XXX in winTray.wndProc", "message", fmt.Sprintf("0x%x", message), "wParam", fmt.Sprintf("0x%x", wParam), "lParam", fmt.Sprintf("0x%x", lParam))
switch message {
case WM_COMMAND:
menuItemId := int32(wParam)
// https://docs.microsoft.com/en-us/windows/win32/menurc/wm-command#menus
switch menuItemId {
case quitMenuID:
select {
case t.callbacks.Quit <- struct{}{}:
// should not happen but in case not listening
default:
slog.Error("no listener on Quit")
}
t.app.Quit()
case updateMenuID:
select {
case t.callbacks.Update <- struct{}{}:
// should not happen but in case not listening
default:
slog.Error("no listener on Update")
}
t.app.DoUpdate()
case openUIMenuID:
// UI must be initialized on this thread so don't use the callbacks
t.app.UIShow()
case settingsUIMenuID:
// UI must be initialized on this thread so don't use the callbacks
t.app.UIRun("/settings")
case diagLogsMenuID:
select {
case t.callbacks.ShowLogs <- struct{}{}:
// should not happen but in case not listening
default:
slog.Error("no listener on ShowLogs")
}
t.showLogs()
default:
slog.Debug(fmt.Sprintf("Unexpected menu item id: %d", menuItemId))
lResult, _, _ = pDefWindowProc.Call(
uintptr(hWnd),
uintptr(message),
wParam,
lParam,
)
}
case WM_CLOSE:
// TODO - does this need adjusting?
// slog.Debug("XXX WM_CLOSE triggered")
boolRet, _, err := pDestroyWindow.Call(uintptr(t.window))
if boolRet == 0 {
slog.Error(fmt.Sprintf("failed to destroy window: %s", err))
}
err = t.wcex.unregister()
if err != nil {
slog.Error(fmt.Sprintf("failed to unregister window %s", err))
slog.Error(fmt.Sprintf("failed to uregister windo %s", err))
}
case WM_DESTROY:
// slog.Debug("XXX WM_DESTROY triggered")
// TODO - does this need adjusting?
// same as WM_ENDSESSION, but throws 0 exit code after all
defer pPostQuitMessage.Call(uintptr(int32(0))) //nolint:errcheck
fallthrough
case WM_ENDSESSION:
// slog.Debug("XXX WM_ENDSESSION triggered")
t.muNID.Lock()
if t.nid != nil {
err := t.nid.delete()
......@@ -124,25 +133,20 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
}
case 0x405: // TODO - how is this magic value derived for the notification left click
if t.pendingUpdate {
select {
case t.callbacks.Update <- struct{}{}:
// should not happen but in case not listening
default:
slog.Error("no listener on Update")
}
} else {
select {
case t.callbacks.DoFirstUse <- struct{}{}:
// should not happen but in case not listening
default:
slog.Error("no listener on DoFirstUse")
}
// TODO - revamp how detecting an update is notified to the user
t.app.DoUpdate()
}
case 0x404: // Middle click or close notification
// slog.Debug("doing nothing on close of first time notification")
default:
// 0x402 also seems common - what is it?
slog.Debug(fmt.Sprintf("unmanaged app message, lParm: 0x%x", lParam))
lResult, _, _ = pDefWindowProc.Call(
uintptr(hWnd),
uintptr(message),
wParam,
lParam,
)
}
case t.wmTaskbarCreated: // on explorer.exe restarts
t.muNID.Lock()
......@@ -151,9 +155,38 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
slog.Error(fmt.Sprintf("failed to refresh the taskbar on explorer restart: %s", err))
}
t.muNID.Unlock()
case uint32(UI_REQUEST_MSG_ID):
// Requests for the UI must always come from the main event thread
l := int(wParam)
path := unsafe.String((*byte)(unsafe.Pointer(lParam)), l)
t.app.UIRun(path)
case WM_COPYDATA:
// Handle URL scheme requests from other instances
if lParam != 0 {
cds := (*COPYDATASTRUCT)(unsafe.Pointer(lParam))
if cds.DwData == 1 { // Our identifier for URL scheme messages
// Convert the data back to string
data := make([]byte, cds.CbData)
copy(data, (*[1 << 30]byte)(unsafe.Pointer(cds.LpData))[:cds.CbData:cds.CbData])
urlScheme := string(data)
handleURLSchemeRequest(urlScheme)
lResult = 1 // Return non-zero to indicate success
}
}
case uint32(FOCUS_WINDOW_MSG_ID):
// Handle focus window request from another instance
if t.app.UIRunning() {
// If UI is already running, just show it
t.app.UIShow()
} else {
// If UI is not running, start it
t.app.UIRun("/")
}
lResult = 1 // Return non-zero to indicate success
default:
// Calls the default window procedure to provide default processing for any window messages that an application does not process.
// https://msdn.microsoft.com/en-us/library/windows/desktop/ms633572(v=vs.85).aspx
// slog.Debug("XXX passing through", "message", fmt.Sprintf("0x%x", message), "wParam", fmt.Sprintf("0x%x", wParam), "lParam", fmt.Sprintf("0x%x", lParam))
lResult, _, _ = pDefWindowProc.Call(
uintptr(hWnd),
uintptr(message),
......@@ -165,9 +198,23 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
}
func (t *winTray) Quit() {
// slog.Debug("XXX in winTray.Quit")
t.quitting = true
quitOnce.Do(quit)
}
func SendUIRequestMessage(path string) {
boolRet, _, err := pPostMessage.Call(
uintptr(wt.window),
uintptr(UI_REQUEST_MSG_ID),
uintptr(len(path)),
uintptr(unsafe.Pointer(unsafe.StringData(path))),
)
if boolRet == 0 {
slog.Error(fmt.Sprintf("failed to post UI request message %s", err))
}
}
func quit() {
boolRet, _, err := pPostMessage.Call(
uintptr(wt.window),
......@@ -179,3 +226,106 @@ func quit() {
slog.Error(fmt.Sprintf("failed to post close message on shutdown %s", err))
}
}
// findExistingInstance attempts to find an existing Ollama instance window
// Returns the window handle if found, 0 if not found
func findExistingInstance() uintptr {
classNamePtr, err := windows.UTF16PtrFromString(ClassName)
if err != nil {
slog.Error("failed to convert class name to UTF16", "error", err)
return 0
}
hwnd, _, _ := pFindWindow.Call(
uintptr(unsafe.Pointer(classNamePtr)),
0, // window name (null = any)
)
return hwnd
}
// CheckAndSendToExistingInstance attempts to send a URL scheme to an existing instance
// Returns true if successfully sent to existing instance, false if no instance found
func CheckAndSendToExistingInstance(urlScheme string) bool {
hwnd := findExistingInstance()
if hwnd == 0 {
// No existing window found
return false
}
data := []byte(urlScheme)
cds := COPYDATASTRUCT{
DwData: 1, // 1 to identify URL scheme messages
CbData: uint32(len(data)),
LpData: uintptr(unsafe.Pointer(&data[0])),
}
result, _, err := pSendMessage.Call(
hwnd,
uintptr(WM_COPYDATA),
0, // wParam is handle to sending window (0 is ok)
uintptr(unsafe.Pointer(&cds)),
)
// SendMessage returns the result from the window procedure
// For WM_COPYDATA, non-zero means success
if result == 0 {
slog.Error("failed to send URL scheme message to existing instance", "error", err)
return false
}
return true
}
// handleURLSchemeRequest processes a URL scheme request
func handleURLSchemeRequest(urlScheme string) {
if urlScheme == "" {
slog.Warn("empty URL scheme request")
return
}
// Call the app callback to handle URL scheme requests
// This will delegate to the main app logic
if wt.app != nil {
if urlHandler, ok := wt.app.(URLSchemeHandler); ok {
urlHandler.HandleURLScheme(urlScheme)
} else {
slog.Warn("app does not implement URLSchemeHandler interface")
}
} else {
slog.Warn("wt.app is nil")
}
}
// CheckAndFocusExistingInstance attempts to find an existing instance and optionally focus it
// Returns true if an existing instance was found, false otherwise
func CheckAndFocusExistingInstance(shouldFocus bool) bool {
hwnd := findExistingInstance()
if hwnd == 0 {
// No existing window found
return false
}
if !shouldFocus {
slog.Info("existing instance found, not focusing due to startHidden")
return true
}
// Send focus message to existing instance
result, _, err := pSendMessage.Call(
hwnd,
uintptr(FOCUS_WINDOW_MSG_ID),
0, // wParam not used
0, // lParam not used
)
// SendMessage returns the result from the window procedure
// For our custom message, non-zero means success
if result == 0 {
slog.Error("failed to send focus message to existing instance", "error", err)
return false
}
slog.Info("sent focus request to existing instance")
return true
}
......@@ -5,6 +5,10 @@ package wintray
import (
"fmt"
"log/slog"
"os"
"os/exec"
"path/filepath"
"syscall"
"unsafe"
"golang.org/x/sys/windows"
......@@ -12,6 +16,9 @@ import (
const (
_ = iota
openUIMenuID
settingsUIMenuID
updateSeparatorMenuID
updateAvailableMenuID
updateMenuID
separatorMenuID
......@@ -21,14 +28,21 @@ const (
)
func (t *winTray) initMenus() error {
if err := t.addOrUpdateMenuItem(openUIMenuID, 0, openUIMenuTitle, false); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
if err := t.addOrUpdateMenuItem(settingsUIMenuID, 0, settingsUIMenuTitle, false); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
if err := t.addOrUpdateMenuItem(diagLogsMenuID, 0, diagLogsMenuTitle, false); err != nil {
return fmt.Errorf("unable to create menu entries %w\n", err)
}
if err := t.addSeparatorMenuItem(diagSeparatorMenuID, 0); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
if err := t.addOrUpdateMenuItem(quitMenuID, 0, quitMenuTitle, false); err != nil {
return fmt.Errorf("unable to create menu entries %w\n", err)
return fmt.Errorf("unable to create menu entries %w", err)
}
return nil
}
......@@ -36,6 +50,9 @@ func (t *winTray) initMenus() error {
func (t *winTray) UpdateAvailable(ver string) error {
if !t.updateNotified {
slog.Debug("updating menu and sending notification for new update")
if err := t.addSeparatorMenuItem(updateSeparatorMenuID, 0); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
if err := t.addOrUpdateMenuItem(updateAvailableMenuID, 0, updateAvailableMenuTitle, true); err != nil {
return fmt.Errorf("unable to create menu entries %w", err)
}
......@@ -70,3 +87,17 @@ func (t *winTray) UpdateAvailable(ver string) error {
}
return nil
}
func (t *winTray) showLogs() error {
localAppData := os.Getenv("LOCALAPPDATA")
AppDataDir := filepath.Join(localAppData, "Ollama")
cmd_path := "c:\\Windows\\system32\\cmd.exe"
slog.Debug(fmt.Sprintf("viewing logs with start %s", AppDataDir))
cmd := exec.Command(cmd_path, "/c", "start", AppDataDir)
cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: false, CreationFlags: 0x08000000}
err := cmd.Start()
if err != nil {
slog.Error(fmt.Sprintf("Failed to open log dir: %s", err))
}
return nil
}
......@@ -12,4 +12,6 @@ const (
updateAvailableMenuTitle = "An update is available"
updateMenuTitle = "Restart to update"
diagLogsMenuTitle = "View logs"
openUIMenuTitle = "Open Ollama"
settingsUIMenuTitle = "Settings..."
)
......@@ -14,17 +14,56 @@ import (
"syscall"
"unsafe"
"github.com/ollama/ollama/app/assets"
"golang.org/x/sys/windows"
)
"github.com/ollama/ollama/app/tray/commontray"
const (
UpdateIconName = "tray_upgrade.ico"
IconName = "tray.ico"
ClassName = "OllamaClass"
)
func NewTray(app AppCallbacks) (TrayCallbacks, error) {
updateIcon, err := assets.GetIcon(UpdateIconName)
if err != nil {
return nil, fmt.Errorf("failed to load icon %s: %w", UpdateIconName, err)
}
icon, err := assets.GetIcon(IconName)
if err != nil {
return nil, fmt.Errorf("failed to load icon %s: %w", IconName, err)
}
return InitTray(icon, updateIcon, app)
}
type TrayCallbacks interface {
Quit()
TrayRun()
UpdateAvailable(ver string) error
GetIconHandle() windows.Handle
}
type AppCallbacks interface {
UIRun(path string)
UIShow()
UITerminate()
UIRunning() bool
Quit()
DoUpdate()
}
type URLSchemeHandler interface {
HandleURLScheme(urlScheme string)
}
// Helpful sources: https://github.com/golang/exp/blob/master/shiny/driver/internal/win32
// Contains information about loaded resources
type winTray struct {
instance,
icon,
defaultIcon,
cursor,
window windows.Handle
......@@ -54,25 +93,21 @@ type winTray struct {
pendingUpdate bool
updateNotified bool // Only pop up the notification once - TODO consider daily nag?
// Callbacks
callbacks commontray.Callbacks
normalIcon []byte
updateIcon []byte
}
normalIcon []byte
updateIcon []byte
var wt winTray
// TODO clean up exit handling
quitting bool
func (t *winTray) GetCallbacks() commontray.Callbacks {
return t.callbacks
app AppCallbacks
}
func InitTray(icon, updateIcon []byte) (*winTray, error) {
wt.callbacks.Quit = make(chan struct{})
wt.callbacks.Update = make(chan struct{})
wt.callbacks.ShowLogs = make(chan struct{})
wt.callbacks.DoFirstUse = make(chan struct{})
var wt winTray
func InitTray(icon, updateIcon []byte, app AppCallbacks) (*winTray, error) {
wt.normalIcon = icon
wt.updateIcon = updateIcon
wt.app = app
if err := wt.initInstance(); err != nil {
return nil, fmt.Errorf("Unable to init instance: %w\n", err)
}
......@@ -89,12 +124,17 @@ func InitTray(icon, updateIcon []byte) (*winTray, error) {
return nil, fmt.Errorf("Unable to set icon: %w", err)
}
h, err := wt.loadIconFrom(iconFilePath)
if err != nil {
return nil, fmt.Errorf("Unable to set default icon: %w", err)
}
wt.defaultIcon = h
return &wt, wt.initMenus()
}
func (t *winTray) initInstance() error {
const (
className = "OllamaClass"
windowName = ""
)
......@@ -135,7 +175,7 @@ func (t *winTray) initInstance() error {
}
t.cursor = windows.Handle(cursorHandle)
classNamePtr, err := windows.UTF16PtrFromString(className)
classNamePtr, err := windows.UTF16PtrFromString(ClassName)
if err != nil {
return err
}
......@@ -435,7 +475,7 @@ func (t *winTray) setIcon(src string) error {
defer t.muNID.Unlock()
t.nid.Icon = h
t.nid.Flags |= NIF_ICON | NIF_TIP
if toolTipUTF16, err := syscall.UTF16FromString(commontray.ToolTip); err == nil {
if toolTipUTF16, err := syscall.UTF16FromString("Ollama"); err == nil {
copy(t.nid.Tip[:], toolTipUTF16)
} else {
return err
......@@ -476,6 +516,10 @@ func (t *winTray) loadIconFrom(src string) (windows.Handle, error) {
return h, nil
}
func (t *winTray) GetIconHandle() windows.Handle {
return t.defaultIcon
}
func (t *winTray) DisplayFirstUseNotification() error {
t.muNID.Lock()
defer t.muNID.Unlock()
......
......@@ -18,6 +18,7 @@ var (
pDefWindowProc = u32.NewProc("DefWindowProcW")
pDestroyWindow = u32.NewProc("DestroyWindow")
pDispatchMessage = u32.NewProc("DispatchMessageW")
pFindWindow = u32.NewProc("FindWindowW")
pGetCursorPos = u32.NewProc("GetCursorPos")
pGetMessage = u32.NewProc("GetMessageW")
pGetModuleHandle = k32.NewProc("GetModuleHandleW")
......@@ -29,6 +30,7 @@ var (
pPostQuitMessage = u32.NewProc("PostQuitMessage")
pRegisterClass = u32.NewProc("RegisterClassExW")
pRegisterWindowMessage = u32.NewProc("RegisterWindowMessageW")
pSendMessage = u32.NewProc("SendMessageW")
pSetForegroundWindow = u32.NewProc("SetForegroundWindow")
pSetMenuInfo = u32.NewProc("SetMenuInfo")
pSetMenuItemInfo = u32.NewProc("SetMenuItemInfoW")
......@@ -69,7 +71,16 @@ const (
TPM_LEFTALIGN = 0x0000
TPM_RIGHTBUTTON = 0x0002
WM_CLOSE = 0x0010
WM_RBUTTONUP = 0x0205
WM_LBUTTONUP = 0x0202
WM_COMMAND = 0x0111
WM_ENDSESSION = 0x0016
WM_QUIT = 0x0012
WM_DESTROY = 0x0002
WM_MOUSEMOVE = 0x0200
WM_LBUTTONDOWN = 0x0201
WM_USER = 0x0400
WM_COPYDATA = 0x004A
WS_CAPTION = 0x00C00000
WS_MAXIMIZEBOX = 0x00010000
WS_MINIMIZEBOX = 0x00020000
......@@ -77,6 +88,8 @@ const (
WS_OVERLAPPEDWINDOW = WS_OVERLAPPED | WS_CAPTION | WS_SYSMENU | WS_THICKFRAME | WS_MINIMIZEBOX | WS_MAXIMIZEBOX
WS_SYSMENU = 0x00080000
WS_THICKFRAME = 0x00040000
MB_OK = 0x00000000
MB_ICONINFORMATION = 0x00000040
)
// Not sure if this is actually needed on windows
......@@ -89,3 +102,11 @@ func init() {
type point struct {
X, Y int32
}
// COPYDATASTRUCT contains data to be passed to another application by WM_COPYDATA
// https://docs.microsoft.com/en-us/windows/win32/api/winuser/ns-winuser-copydatastruct
type COPYDATASTRUCT struct {
DwData uintptr
CbData uint32
LpData uintptr
}
package discover
import (
"log/slog"
"os"
"testing"
"github.com/ollama/ollama/app/lifecycle"
)
func init() {
lifecycle.InitLogging()
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
}
func TestFilterOverlapByLibrary(t *testing.T) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment